@@ -26,6 +26,13 @@ namespace nvfuser {
2626
2727namespace {
2828
29+ struct CommunicationLoweringParams {
30+ CommunicatorBackend backend{};
31+ DeviceIdxType my_device_idx{};
32+ Val* host_loop_index{};
33+ std::optional<BinaryOpType> reduction_op;
34+ };
35+
2936// TODO: handle `c10d::RedOpType::reduceOp::AVG` and
3037// `c10d::RedOpType::reduceOp::PREMUL_SUM`
3138c10d::ReduceOp::RedOpType getC10dReduceOpType (BinaryOpType op) {
@@ -55,7 +62,7 @@ c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) {
5562void lowerToScatter (
5663 TensorView* input_tv,
5764 TensorView* output_tv,
58- const CommunicatorBackend backend ,
65+ const CommunicationLoweringParams& params ,
5966 std::vector<Expr*>& comms) {
6067 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
6168 NVF_ERROR_EQ (
@@ -83,7 +90,7 @@ void lowerToScatter(
8390 team,
8491 getRelativeIndex (team, root),
8592 c10d::ReduceOp::RedOpType::UNUSED,
86- backend));
93+ params. backend ));
8794}
8895
8996// Adds zero or multiple Gather communications to the vector 'comms'
@@ -93,7 +100,7 @@ void lowerToScatter(
93100void lowerToGather (
94101 TensorView* input_tv,
95102 TensorView* output_tv,
96- const CommunicatorBackend backend ,
103+ const CommunicationLoweringParams& params ,
97104 std::vector<Expr*>& comms) {
98105 // we create as many 'Gathers' as there are devices in the receiver mesh
99106 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
@@ -114,27 +121,26 @@ void lowerToGather(
114121 team,
115122 getRelativeIndex (team, root),
116123 c10d::ReduceOp::RedOpType::UNUSED,
117- backend));
124+ params. backend ));
118125 }
119126}
120127
121128// Add one or zero Allgather communication to the vector 'comms'
122129void lowerToAllgather (
123130 TensorView* input_tv,
124131 TensorView* output_tv,
125- const CommunicatorBackend backend,
126- std::vector<Expr*>& comms,
127- DeviceIdxType my_device_idx) {
132+ const CommunicationLoweringParams& params,
133+ std::vector<Expr*>& comms) {
128134 const DeviceMesh& mesh = input_tv->getDeviceMesh ();
129- Team team = mesh.getSlice (my_device_idx, ParallelType::DIDx);
135+ Team team = mesh.getSlice (params. my_device_idx , ParallelType::DIDx);
130136 comms.push_back (IrBuilder::create<Communication>(
131137 CommunicationType::Allgather,
132138 output_tv,
133139 input_tv,
134140 team,
135141 /* root=*/ -1 ,
136142 c10d::ReduceOp::RedOpType::UNUSED,
137- backend));
143+ params. backend ));
138144}
139145
140146// Either of the following cases is happening:
@@ -144,13 +150,13 @@ void lowerToAllgather(
144150void lowerToBroadcast (
145151 TensorView* input_tv,
146152 TensorView* output_tv,
147- const CommunicatorBackend backend,
148- Val* root,
153+ const CommunicationLoweringParams& params,
149154 std::vector<Expr*>& comms) {
150155 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
151156 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
152157
153158 Team team = receiver_mesh.vector ();
159+ Val* root = params.host_loop_index ;
154160
155161 if (sender_mesh == receiver_mesh) {
156162 NVF_ERROR (
@@ -175,7 +181,7 @@ void lowerToBroadcast(
175181 team,
176182 root,
177183 c10d::ReduceOp::RedOpType::UNUSED,
178- backend));
184+ params. backend ));
179185}
180186
181187// Adds several SendRecv communications to the vector 'comms'
@@ -185,7 +191,7 @@ void lowerToBroadcast(
185191void lowerToSendRecv (
186192 TensorView* input_tv,
187193 TensorView* output_tv,
188- const CommunicatorBackend backend ,
194+ const CommunicationLoweringParams& params ,
189195 std::vector<Expr*>& comms) {
190196 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
191197 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
@@ -214,16 +220,18 @@ void lowerToSendRecv(
214220 team,
215221 /* root=*/ getRelativeIndex (team, sender),
216222 c10d::ReduceOp::RedOpType::UNUSED,
217- backend));
223+ params. backend ));
218224 }
219225}
220226
221227void lowerToReduce (
222228 TensorView* input_tv,
223229 TensorView* output_tv,
224- BinaryOpType op_type,
225- const CommunicatorBackend backend,
230+ const CommunicationLoweringParams& params,
226231 std::vector<Expr*>& comms) {
232+ NVF_ERROR (
233+ params.reduction_op .has_value (),
234+ " Reduce communication requires reduction_op in params" );
227235 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
228236 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
229237 NVF_ERROR_EQ (
@@ -236,7 +244,8 @@ void lowerToReduce(
236244 1 ,
237245 " Reduce only supported a 1D mesh. Given " ,
238246 receiver_mesh);
239- const auto reduce_op_type = getC10dReduceOpType (op_type);
247+ const auto reduce_op_type =
248+ getC10dReduceOpType (valueOrError (params.reduction_op ));
240249 // we create as many Reduces as there are devices in the receiver mesh
241250 for (auto root : receiver_mesh.vector ()) {
242251 Team team = sender_mesh.vector ();
@@ -250,36 +259,38 @@ void lowerToReduce(
250259 team,
251260 getRelativeIndex (team, root),
252261 reduce_op_type,
253- backend));
262+ params. backend ));
254263 }
255264}
256265
257266void lowerToAllreduce (
258267 TensorView* input_tv,
259268 TensorView* output_tv,
260- BinaryOpType op_type,
261- const CommunicatorBackend backend,
262- std::vector<Expr*>& comms,
263- DeviceIdxType my_device_idx) {
269+ const CommunicationLoweringParams& params,
270+ std::vector<Expr*>& comms) {
271+ NVF_ERROR (
272+ params.reduction_op .has_value (),
273+ " Allreduce communication requires reduction_op in params" );
264274 const DeviceMesh& mesh = input_tv->getDeviceMesh ();
265- Team team = mesh.getSlice (my_device_idx, ParallelType::DIDx);
275+ Team team = mesh.getSlice (params. my_device_idx , ParallelType::DIDx);
266276 comms.push_back (IrBuilder::create<Communication>(
267277 CommunicationType::Allreduce,
268278 output_tv,
269279 input_tv,
270280 team,
271281 /* root=*/ -1 ,
272- getC10dReduceOpType (op_type ),
273- backend));
282+ getC10dReduceOpType (valueOrError (params. reduction_op ) ),
283+ params. backend ));
274284}
275285
276286void lowerToReduceScatter (
277287 TensorView* input_tv,
278288 TensorView* output_tv,
279- BinaryOpType op_type,
280- const CommunicatorBackend backend,
281- std::vector<Expr*>& comms,
282- DeviceIdxType my_device_idx) {
289+ const CommunicationLoweringParams& params,
290+ std::vector<Expr*>& comms) {
291+ NVF_ERROR (
292+ params.reduction_op .has_value (),
293+ " ReduceScatter communication requires reduction_op in params" );
283294 NVF_ERROR_EQ (
284295 input_tv->getDeviceMesh (),
285296 output_tv->getDeviceMesh (),
@@ -288,22 +299,22 @@ void lowerToReduceScatter(
288299 " Insert a Set operation before or after the reduction to reshard to "
289300 " another device mesh" );
290301 const DeviceMesh& mesh = input_tv->getDeviceMesh ();
291- Team team = mesh.getSlice (my_device_idx, ParallelType::DIDx);
302+ Team team = mesh.getSlice (params. my_device_idx , ParallelType::DIDx);
292303
293304 comms.push_back (IrBuilder::create<Communication>(
294305 CommunicationType::ReduceScatter,
295306 output_tv,
296307 input_tv,
297308 /* team=*/ team,
298309 /* root=*/ -1 ,
299- getC10dReduceOpType (op_type ),
300- backend));
310+ getC10dReduceOpType (valueOrError (params. reduction_op ) ),
311+ params. backend ));
301312}
302313
303314void lowerToAllToAll (
304315 TensorView* input_tv,
305316 TensorView* output_tv,
306- const CommunicatorBackend backend ,
317+ const CommunicationLoweringParams& params ,
307318 std::vector<Expr*>& comms) {
308319 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
309320 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
@@ -331,16 +342,14 @@ void lowerToAllToAll(
331342 sender_mesh.vector (),
332343 /* root=*/ -1 ,
333344 c10d::ReduceOp::RedOpType::UNUSED,
334- backend));
345+ params. backend ));
335346}
336347
337348void lowerToCollectivePermute (
338349 TensorView* input_tv,
339350 TensorView* output_tv,
340- const CommunicatorBackend backend,
341- std::vector<Expr*>& comms,
342- Val* host_loop_index,
343- DeviceIdxType my_device_idx) {
351+ const CommunicationLoweringParams& params,
352+ std::vector<Expr*>& comms) {
344353 NVF_ERROR_EQ (
345354 input_tv->getDeviceMesh (),
346355 output_tv->getDeviceMesh (),
@@ -350,7 +359,7 @@ void lowerToCollectivePermute(
350359 output_tv->getDeviceMesh ());
351360
352361 NVF_ERROR (
353- host_loop_index != nullptr ,
362+ params. host_loop_index != nullptr ,
354363 " Host loop index must be provided for CollectivePermute." );
355364
356365 IterDomain* stream_id =
@@ -362,10 +371,13 @@ void lowerToCollectivePermute(
362371 ParallelType pt = swizzle->parallelType ();
363372
364373 const auto & [recv_peer, send_peer] = dispatchSwizzle1D (
365- host_loop_index, my_device_idx, pt, input_tv->getDeviceMesh ());
374+ params.host_loop_index ,
375+ params.my_device_idx ,
376+ pt,
377+ input_tv->getDeviceMesh ());
366378 Team team = input_tv->getDeviceMesh ().vector ();
367379 comms.push_back (IrBuilder::create<CollectivePermute>(
368- output_tv, input_tv, team, send_peer, recv_peer, backend));
380+ output_tv, input_tv, team, send_peer, recv_peer, params. backend ));
369381}
370382
371383IterDomain* getLogicalFromLoopId (TensorView* tv, IterDomain* loop_id) {
@@ -657,52 +669,52 @@ std::vector<Expr*> convertSingleOpToCommunication(
657669 " Expected communication info for resharding expr: " ,
658670 e);
659671
660- auto op_type = [](Expr* e) -> BinaryOpType {
672+ auto op_type = [](Expr* e) -> std::optional< BinaryOpType> {
661673 if (auto * reduce = dynamic_cast <ReductionOp*>(e)) {
662674 return reduce->getReductionOpType ();
663675 }
664-
665- NVF_ERROR (e != nullptr );
666- if (e->isA <SqueezeOp>()) {
676+ if (e != nullptr && e->isA <SqueezeOp>()) {
667677 return BinaryOpType::Add;
668678 }
669-
670- NVF_THROW (" Expected a ReductionOp or a SqueezeOp, but got: " , e);
679+ return std::nullopt ;
671680 };
672681
682+ CommunicationLoweringParams params{
683+ .backend = backend,
684+ .my_device_idx = my_device_idx,
685+ .host_loop_index = host_loop_index,
686+ .reduction_op = op_type (e)};
687+
673688 switch (communication_info->type ) {
674689 case CommunicationType::Scatter:
675- lowerToScatter (input_tv, output_tv, backend , comms);
690+ lowerToScatter (input_tv, output_tv, params , comms);
676691 break ;
677692 case CommunicationType::Gather:
678- lowerToGather (input_tv, output_tv, backend , comms);
693+ lowerToGather (input_tv, output_tv, params , comms);
679694 break ;
680695 case CommunicationType::Allgather:
681- lowerToAllgather (input_tv, output_tv, backend , comms, my_device_idx );
696+ lowerToAllgather (input_tv, output_tv, params , comms);
682697 break ;
683698 case CommunicationType::Broadcast:
684- lowerToBroadcast (input_tv, output_tv, backend, host_loop_index , comms);
699+ lowerToBroadcast (input_tv, output_tv, params , comms);
685700 break ;
686701 case CommunicationType::SendRecv:
687- lowerToSendRecv (input_tv, output_tv, backend , comms);
702+ lowerToSendRecv (input_tv, output_tv, params , comms);
688703 break ;
689704 case CommunicationType::ReduceScatter:
690- lowerToReduceScatter (
691- input_tv, output_tv, op_type (e), backend, comms, my_device_idx);
705+ lowerToReduceScatter (input_tv, output_tv, params, comms);
692706 break ;
693707 case CommunicationType::Allreduce:
694- lowerToAllreduce (
695- input_tv, output_tv, op_type (e), backend, comms, my_device_idx);
708+ lowerToAllreduce (input_tv, output_tv, params, comms);
696709 break ;
697710 case CommunicationType::Reduce:
698- lowerToReduce (input_tv, output_tv, op_type (e), backend , comms);
711+ lowerToReduce (input_tv, output_tv, params , comms);
699712 break ;
700713 case CommunicationType::AllToAll:
701- lowerToAllToAll (input_tv, output_tv, backend , comms);
714+ lowerToAllToAll (input_tv, output_tv, params , comms);
702715 break ;
703716 case CommunicationType::CollectivePermute:
704- lowerToCollectivePermute (
705- input_tv, output_tv, backend, comms, host_loop_index, my_device_idx);
717+ lowerToCollectivePermute (input_tv, output_tv, params, comms);
706718 break ;
707719 }
708720
0 commit comments