-
Notifications
You must be signed in to change notification settings - Fork 80
Encapsulate communication lowering variables into a struct #6037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,13 @@ namespace nvfuser { | |
|
|
||
| namespace { | ||
|
|
||
| struct CommunicationLoweringParams { | ||
| CommunicatorBackend backend{}; | ||
| DeviceIdxType my_device_idx{}; | ||
| Val* host_loop_index{}; | ||
| std::optional<BinaryOpType> reduction_op; | ||
| }; | ||
|
|
||
| // TODO: handle `c10d::RedOpType::reduceOp::AVG` and | ||
| // `c10d::RedOpType::reduceOp::PREMUL_SUM` | ||
| c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { | ||
|
|
@@ -55,7 +62,7 @@ c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { | |
| void lowerToScatter( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); | ||
| NVF_ERROR_EQ( | ||
|
|
@@ -83,7 +90,7 @@ void lowerToScatter( | |
| team, | ||
| getRelativeIndex(team, root), | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
|
|
||
| // Adds zero or multiple Gather communications to the vector 'comms' | ||
|
|
@@ -93,7 +100,7 @@ void lowerToScatter( | |
| void lowerToGather( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| // we create as many 'Gathers' as there are devices in the receiver mesh | ||
| const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); | ||
|
|
@@ -114,27 +121,26 @@ void lowerToGather( | |
| team, | ||
| getRelativeIndex(team, root), | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
| } | ||
|
|
||
| // Add one or zero Allgather communication to the vector 'comms' | ||
| void lowerToAllgather( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| std::vector<Expr*>& comms, | ||
| DeviceIdxType my_device_idx) { | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| const DeviceMesh& mesh = input_tv->getDeviceMesh(); | ||
| Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); | ||
| Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx); | ||
| comms.push_back(IrBuilder::create<Communication>( | ||
| CommunicationType::Allgather, | ||
| output_tv, | ||
| input_tv, | ||
| team, | ||
| /*root=*/-1, | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
|
|
||
| // Either of the following cases is happening: | ||
|
|
@@ -144,13 +150,13 @@ void lowerToAllgather( | |
| void lowerToBroadcast( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| Val* root, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); | ||
| const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); | ||
|
|
||
| Team team = receiver_mesh.vector(); | ||
| Val* root = params.host_loop_index; | ||
|
|
||
| if (sender_mesh == receiver_mesh) { | ||
| NVF_ERROR( | ||
|
|
@@ -175,7 +181,7 @@ void lowerToBroadcast( | |
| team, | ||
| root, | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
|
|
||
| // Adds several SendRecv communications to the vector 'comms' | ||
|
|
@@ -185,7 +191,7 @@ void lowerToBroadcast( | |
| void lowerToSendRecv( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); | ||
| const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); | ||
|
|
@@ -214,16 +220,18 @@ void lowerToSendRecv( | |
| team, | ||
| /*root=*/getRelativeIndex(team, sender), | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
| } | ||
|
|
||
| void lowerToReduce( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| BinaryOpType op_type, | ||
| const CommunicatorBackend backend, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| NVF_ERROR( | ||
| params.reduction_op.has_value(), | ||
| "Reduce communication requires reduction_op in params"); | ||
| const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); | ||
| const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); | ||
| NVF_ERROR_EQ( | ||
|
|
@@ -236,7 +244,8 @@ void lowerToReduce( | |
| 1, | ||
| "Reduce only supported a 1D mesh. Given ", | ||
| receiver_mesh); | ||
| const auto reduce_op_type = getC10dReduceOpType(op_type); | ||
| const auto reduce_op_type = | ||
| getC10dReduceOpType(valueOrError(params.reduction_op)); | ||
| // we create as many Reduces as there are devices in the receiver mesh | ||
| for (auto root : receiver_mesh.vector()) { | ||
| Team team = sender_mesh.vector(); | ||
|
|
@@ -250,36 +259,38 @@ void lowerToReduce( | |
| team, | ||
| getRelativeIndex(team, root), | ||
| reduce_op_type, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
| } | ||
|
|
||
| void lowerToAllreduce( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| BinaryOpType op_type, | ||
| const CommunicatorBackend backend, | ||
| std::vector<Expr*>& comms, | ||
| DeviceIdxType my_device_idx) { | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| NVF_ERROR( | ||
| params.reduction_op.has_value(), | ||
| "Allreduce communication requires reduction_op in params"); | ||
| const DeviceMesh& mesh = input_tv->getDeviceMesh(); | ||
| Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); | ||
| Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx); | ||
| comms.push_back(IrBuilder::create<Communication>( | ||
| CommunicationType::Allreduce, | ||
| output_tv, | ||
| input_tv, | ||
| team, | ||
| /*root=*/-1, | ||
| getC10dReduceOpType(op_type), | ||
| backend)); | ||
| getC10dReduceOpType(valueOrError(params.reduction_op)), | ||
| params.backend)); | ||
| } | ||
|
|
||
| void lowerToReduceScatter( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| BinaryOpType op_type, | ||
| const CommunicatorBackend backend, | ||
| std::vector<Expr*>& comms, | ||
| DeviceIdxType my_device_idx) { | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| NVF_ERROR( | ||
| params.reduction_op.has_value(), | ||
| "ReduceScatter communication requires reduction_op in params"); | ||
| NVF_ERROR_EQ( | ||
| input_tv->getDeviceMesh(), | ||
| output_tv->getDeviceMesh(), | ||
|
|
@@ -288,22 +299,22 @@ void lowerToReduceScatter( | |
| "Insert a Set operation before or after the reduction to reshard to " | ||
| "another device mesh"); | ||
| const DeviceMesh& mesh = input_tv->getDeviceMesh(); | ||
| Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); | ||
| Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx); | ||
|
|
||
| comms.push_back(IrBuilder::create<Communication>( | ||
| CommunicationType::ReduceScatter, | ||
| output_tv, | ||
| input_tv, | ||
| /*team=*/team, | ||
| /*root=*/-1, | ||
| getC10dReduceOpType(op_type), | ||
| backend)); | ||
| getC10dReduceOpType(valueOrError(params.reduction_op)), | ||
| params.backend)); | ||
| } | ||
|
|
||
| void lowerToAllToAll( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); | ||
| const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); | ||
|
|
@@ -331,16 +342,14 @@ void lowerToAllToAll( | |
| sender_mesh.vector(), | ||
| /*root=*/-1, | ||
| c10d::ReduceOp::RedOpType::UNUSED, | ||
| backend)); | ||
| params.backend)); | ||
| } | ||
|
|
||
| void lowerToCollectivePermute( | ||
| TensorView* input_tv, | ||
| TensorView* output_tv, | ||
| const CommunicatorBackend backend, | ||
| std::vector<Expr*>& comms, | ||
| Val* host_loop_index, | ||
| DeviceIdxType my_device_idx) { | ||
| const CommunicationLoweringParams& params, | ||
| std::vector<Expr*>& comms) { | ||
| NVF_ERROR_EQ( | ||
| input_tv->getDeviceMesh(), | ||
| output_tv->getDeviceMesh(), | ||
|
|
@@ -350,7 +359,7 @@ void lowerToCollectivePermute( | |
| output_tv->getDeviceMesh()); | ||
|
|
||
| NVF_ERROR( | ||
| host_loop_index != nullptr, | ||
| params.host_loop_index != nullptr, | ||
| "Host loop index must be provided for CollectivePermute."); | ||
|
|
||
| IterDomain* stream_id = | ||
|
|
@@ -362,10 +371,13 @@ void lowerToCollectivePermute( | |
| ParallelType pt = swizzle->parallelType(); | ||
|
|
||
| const auto& [recv_peer, send_peer] = dispatchSwizzle1D( | ||
| host_loop_index, my_device_idx, pt, input_tv->getDeviceMesh()); | ||
| params.host_loop_index, | ||
| params.my_device_idx, | ||
| pt, | ||
| input_tv->getDeviceMesh()); | ||
| Team team = input_tv->getDeviceMesh().vector(); | ||
| comms.push_back(IrBuilder::create<CollectivePermute>( | ||
| output_tv, input_tv, team, send_peer, recv_peer, backend)); | ||
| output_tv, input_tv, team, send_peer, recv_peer, params.backend)); | ||
| } | ||
|
|
||
| IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { | ||
|
|
@@ -657,52 +669,52 @@ std::vector<Expr*> convertSingleOpToCommunication( | |
| "Expected communication info for resharding expr: ", | ||
| e); | ||
|
|
||
| auto op_type = [](Expr* e) -> BinaryOpType { | ||
| auto op_type = [](Expr* e) -> std::optional<BinaryOpType> { | ||
| if (auto* reduce = dynamic_cast<ReductionOp*>(e)) { | ||
| return reduce->getReductionOpType(); | ||
| } | ||
|
|
||
| NVF_ERROR(e != nullptr); | ||
| if (e->isA<SqueezeOp>()) { | ||
| if (e != nullptr && e->isA<SqueezeOp>()) { | ||
| return BinaryOpType::Add; | ||
| } | ||
|
|
||
| NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e); | ||
| return std::nullopt; | ||
| }; | ||
|
|
||
| CommunicationLoweringParams params{ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks good. You may want to consider the builder design pattern used by `at::TensorOptions`: https://docs.pytorch.org/cppdocs/notes/tensor_creation.html#configuring-properties-of-the-tensor. It's a more fluent way to build the parameter without having to mechanically create a struct and overwrite fields.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For now, I prefer the struct. All the fields need to be overwritten (even if I set the common values as default). If a case arises where some fields are always going to be these default values, then I can consider switching the API.
||
| .backend = backend, | ||
| .my_device_idx = my_device_idx, | ||
| .host_loop_index = host_loop_index, | ||
| .reduction_op = op_type(e)}; | ||
|
|
||
| switch (communication_info->type) { | ||
| case CommunicationType::Scatter: | ||
| lowerToScatter(input_tv, output_tv, backend, comms); | ||
| lowerToScatter(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::Gather: | ||
| lowerToGather(input_tv, output_tv, backend, comms); | ||
| lowerToGather(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::Allgather: | ||
| lowerToAllgather(input_tv, output_tv, backend, comms, my_device_idx); | ||
| lowerToAllgather(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::Broadcast: | ||
| lowerToBroadcast(input_tv, output_tv, backend, host_loop_index, comms); | ||
| lowerToBroadcast(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::SendRecv: | ||
| lowerToSendRecv(input_tv, output_tv, backend, comms); | ||
| lowerToSendRecv(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::ReduceScatter: | ||
| lowerToReduceScatter( | ||
| input_tv, output_tv, op_type(e), backend, comms, my_device_idx); | ||
| lowerToReduceScatter(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::Allreduce: | ||
| lowerToAllreduce( | ||
| input_tv, output_tv, op_type(e), backend, comms, my_device_idx); | ||
| lowerToAllreduce(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::Reduce: | ||
| lowerToReduce(input_tv, output_tv, op_type(e), backend, comms); | ||
| lowerToReduce(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::AllToAll: | ||
| lowerToAllToAll(input_tv, output_tv, backend, comms); | ||
| lowerToAllToAll(input_tv, output_tv, params, comms); | ||
| break; | ||
| case CommunicationType::CollectivePermute: | ||
| lowerToCollectivePermute( | ||
| input_tv, output_tv, backend, comms, host_loop_index, my_device_idx); | ||
| lowerToCollectivePermute(input_tv, output_tv, params, comms); | ||
| break; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`op_type` lambda: The old lambda threw an explicit, informative error when the expression was neither a `ReductionOp` nor a `SqueezeOp`. The new lambda silently returns `std::nullopt` for any such expression. The error is now deferred to the individual `lowerToReduce` / `lowerToAllreduce` / `lowerToReduceScatter` functions, but those error messages ("Reduce communication requires reduction_op in params", etc.) don't include the actual expression `e` that caused the problem.

As the comment in `base.h` above `valueOrError` itself notes: "If you prefer a better error message, have the caller check `has_value()` instead so it can provide more context." The current error messages do check `has_value()` but drop the expression context. Consider propagating `e` into those error messages, or restoring the `NVF_THROW` in the lambda for the case when `e` is non-null and of an unrecognized type: