Skip to content

Commit fe948b6

Browse files
authored
Encapsulate communication lowering variables into a struct (#6037)
This allows me to easily pass in new variables without changing individual signatures for all `lowerTo*` methods.
1 parent fcb2a48 commit fe948b6

1 file changed

Lines changed: 72 additions & 60 deletions

File tree

csrc/host_ir/lower_to_communication.cpp

Lines changed: 72 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,13 @@ namespace nvfuser {
2626

2727
namespace {
2828

29+
struct CommunicationLoweringParams {
30+
CommunicatorBackend backend{};
31+
DeviceIdxType my_device_idx{};
32+
Val* host_loop_index{};
33+
std::optional<BinaryOpType> reduction_op;
34+
};
35+
2936
// TODO: handle `c10d::RedOpType::reduceOp::AVG` and
3037
// `c10d::RedOpType::reduceOp::PREMUL_SUM`
3138
c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) {
@@ -55,7 +62,7 @@ c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) {
5562
void lowerToScatter(
5663
TensorView* input_tv,
5764
TensorView* output_tv,
58-
const CommunicatorBackend backend,
65+
const CommunicationLoweringParams& params,
5966
std::vector<Expr*>& comms) {
6067
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
6168
NVF_ERROR_EQ(
@@ -83,7 +90,7 @@ void lowerToScatter(
8390
team,
8491
getRelativeIndex(team, root),
8592
c10d::ReduceOp::RedOpType::UNUSED,
86-
backend));
93+
params.backend));
8794
}
8895

8996
// Adds zero or multiple Gather communications to the vector 'comms'
@@ -93,7 +100,7 @@ void lowerToScatter(
93100
void lowerToGather(
94101
TensorView* input_tv,
95102
TensorView* output_tv,
96-
const CommunicatorBackend backend,
103+
const CommunicationLoweringParams& params,
97104
std::vector<Expr*>& comms) {
98105
// we create as many 'Gathers' as there are devices in the receiver mesh
99106
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
@@ -114,27 +121,26 @@ void lowerToGather(
114121
team,
115122
getRelativeIndex(team, root),
116123
c10d::ReduceOp::RedOpType::UNUSED,
117-
backend));
124+
params.backend));
118125
}
119126
}
120127

121128
// Add one or zero Allgather communication to the vector 'comms'
122129
void lowerToAllgather(
123130
TensorView* input_tv,
124131
TensorView* output_tv,
125-
const CommunicatorBackend backend,
126-
std::vector<Expr*>& comms,
127-
DeviceIdxType my_device_idx) {
132+
const CommunicationLoweringParams& params,
133+
std::vector<Expr*>& comms) {
128134
const DeviceMesh& mesh = input_tv->getDeviceMesh();
129-
Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx);
135+
Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);
130136
comms.push_back(IrBuilder::create<Communication>(
131137
CommunicationType::Allgather,
132138
output_tv,
133139
input_tv,
134140
team,
135141
/*root=*/-1,
136142
c10d::ReduceOp::RedOpType::UNUSED,
137-
backend));
143+
params.backend));
138144
}
139145

140146
// Either of the following cases is happening:
@@ -144,13 +150,13 @@ void lowerToAllgather(
144150
void lowerToBroadcast(
145151
TensorView* input_tv,
146152
TensorView* output_tv,
147-
const CommunicatorBackend backend,
148-
Val* root,
153+
const CommunicationLoweringParams& params,
149154
std::vector<Expr*>& comms) {
150155
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
151156
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
152157

153158
Team team = receiver_mesh.vector();
159+
Val* root = params.host_loop_index;
154160

155161
if (sender_mesh == receiver_mesh) {
156162
NVF_ERROR(
@@ -175,7 +181,7 @@ void lowerToBroadcast(
175181
team,
176182
root,
177183
c10d::ReduceOp::RedOpType::UNUSED,
178-
backend));
184+
params.backend));
179185
}
180186

181187
// Adds several SendRecv communications to the vector 'comms'
@@ -185,7 +191,7 @@ void lowerToBroadcast(
185191
void lowerToSendRecv(
186192
TensorView* input_tv,
187193
TensorView* output_tv,
188-
const CommunicatorBackend backend,
194+
const CommunicationLoweringParams& params,
189195
std::vector<Expr*>& comms) {
190196
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
191197
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
@@ -214,16 +220,18 @@ void lowerToSendRecv(
214220
team,
215221
/*root=*/getRelativeIndex(team, sender),
216222
c10d::ReduceOp::RedOpType::UNUSED,
217-
backend));
223+
params.backend));
218224
}
219225
}
220226

221227
void lowerToReduce(
222228
TensorView* input_tv,
223229
TensorView* output_tv,
224-
BinaryOpType op_type,
225-
const CommunicatorBackend backend,
230+
const CommunicationLoweringParams& params,
226231
std::vector<Expr*>& comms) {
232+
NVF_ERROR(
233+
params.reduction_op.has_value(),
234+
"Reduce communication requires reduction_op in params");
227235
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
228236
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
229237
NVF_ERROR_EQ(
@@ -236,7 +244,8 @@ void lowerToReduce(
236244
1,
237245
"Reduce only supported a 1D mesh. Given ",
238246
receiver_mesh);
239-
const auto reduce_op_type = getC10dReduceOpType(op_type);
247+
const auto reduce_op_type =
248+
getC10dReduceOpType(valueOrError(params.reduction_op));
240249
// we create as many Reduces as there are devices in the receiver mesh
241250
for (auto root : receiver_mesh.vector()) {
242251
Team team = sender_mesh.vector();
@@ -250,36 +259,38 @@ void lowerToReduce(
250259
team,
251260
getRelativeIndex(team, root),
252261
reduce_op_type,
253-
backend));
262+
params.backend));
254263
}
255264
}
256265

257266
void lowerToAllreduce(
258267
TensorView* input_tv,
259268
TensorView* output_tv,
260-
BinaryOpType op_type,
261-
const CommunicatorBackend backend,
262-
std::vector<Expr*>& comms,
263-
DeviceIdxType my_device_idx) {
269+
const CommunicationLoweringParams& params,
270+
std::vector<Expr*>& comms) {
271+
NVF_ERROR(
272+
params.reduction_op.has_value(),
273+
"Allreduce communication requires reduction_op in params");
264274
const DeviceMesh& mesh = input_tv->getDeviceMesh();
265-
Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx);
275+
Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);
266276
comms.push_back(IrBuilder::create<Communication>(
267277
CommunicationType::Allreduce,
268278
output_tv,
269279
input_tv,
270280
team,
271281
/*root=*/-1,
272-
getC10dReduceOpType(op_type),
273-
backend));
282+
getC10dReduceOpType(valueOrError(params.reduction_op)),
283+
params.backend));
274284
}
275285

276286
void lowerToReduceScatter(
277287
TensorView* input_tv,
278288
TensorView* output_tv,
279-
BinaryOpType op_type,
280-
const CommunicatorBackend backend,
281-
std::vector<Expr*>& comms,
282-
DeviceIdxType my_device_idx) {
289+
const CommunicationLoweringParams& params,
290+
std::vector<Expr*>& comms) {
291+
NVF_ERROR(
292+
params.reduction_op.has_value(),
293+
"ReduceScatter communication requires reduction_op in params");
283294
NVF_ERROR_EQ(
284295
input_tv->getDeviceMesh(),
285296
output_tv->getDeviceMesh(),
@@ -288,22 +299,22 @@ void lowerToReduceScatter(
288299
"Insert a Set operation before or after the reduction to reshard to "
289300
"another device mesh");
290301
const DeviceMesh& mesh = input_tv->getDeviceMesh();
291-
Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx);
302+
Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);
292303

293304
comms.push_back(IrBuilder::create<Communication>(
294305
CommunicationType::ReduceScatter,
295306
output_tv,
296307
input_tv,
297308
/*team=*/team,
298309
/*root=*/-1,
299-
getC10dReduceOpType(op_type),
300-
backend));
310+
getC10dReduceOpType(valueOrError(params.reduction_op)),
311+
params.backend));
301312
}
302313

303314
void lowerToAllToAll(
304315
TensorView* input_tv,
305316
TensorView* output_tv,
306-
const CommunicatorBackend backend,
317+
const CommunicationLoweringParams& params,
307318
std::vector<Expr*>& comms) {
308319
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
309320
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
@@ -331,16 +342,14 @@ void lowerToAllToAll(
331342
sender_mesh.vector(),
332343
/*root=*/-1,
333344
c10d::ReduceOp::RedOpType::UNUSED,
334-
backend));
345+
params.backend));
335346
}
336347

337348
void lowerToCollectivePermute(
338349
TensorView* input_tv,
339350
TensorView* output_tv,
340-
const CommunicatorBackend backend,
341-
std::vector<Expr*>& comms,
342-
Val* host_loop_index,
343-
DeviceIdxType my_device_idx) {
351+
const CommunicationLoweringParams& params,
352+
std::vector<Expr*>& comms) {
344353
NVF_ERROR_EQ(
345354
input_tv->getDeviceMesh(),
346355
output_tv->getDeviceMesh(),
@@ -350,7 +359,7 @@ void lowerToCollectivePermute(
350359
output_tv->getDeviceMesh());
351360

352361
NVF_ERROR(
353-
host_loop_index != nullptr,
362+
params.host_loop_index != nullptr,
354363
"Host loop index must be provided for CollectivePermute.");
355364

356365
IterDomain* stream_id =
@@ -362,10 +371,13 @@ void lowerToCollectivePermute(
362371
ParallelType pt = swizzle->parallelType();
363372

364373
const auto& [recv_peer, send_peer] = dispatchSwizzle1D(
365-
host_loop_index, my_device_idx, pt, input_tv->getDeviceMesh());
374+
params.host_loop_index,
375+
params.my_device_idx,
376+
pt,
377+
input_tv->getDeviceMesh());
366378
Team team = input_tv->getDeviceMesh().vector();
367379
comms.push_back(IrBuilder::create<CollectivePermute>(
368-
output_tv, input_tv, team, send_peer, recv_peer, backend));
380+
output_tv, input_tv, team, send_peer, recv_peer, params.backend));
369381
}
370382

371383
IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
@@ -657,52 +669,52 @@ std::vector<Expr*> convertSingleOpToCommunication(
657669
"Expected communication info for resharding expr: ",
658670
e);
659671

660-
auto op_type = [](Expr* e) -> BinaryOpType {
672+
auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
661673
if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
662674
return reduce->getReductionOpType();
663675
}
664-
665-
NVF_ERROR(e != nullptr);
666-
if (e->isA<SqueezeOp>()) {
676+
if (e != nullptr && e->isA<SqueezeOp>()) {
667677
return BinaryOpType::Add;
668678
}
669-
670-
NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);
679+
return std::nullopt;
671680
};
672681

682+
CommunicationLoweringParams params{
683+
.backend = backend,
684+
.my_device_idx = my_device_idx,
685+
.host_loop_index = host_loop_index,
686+
.reduction_op = op_type(e)};
687+
673688
switch (communication_info->type) {
674689
case CommunicationType::Scatter:
675-
lowerToScatter(input_tv, output_tv, backend, comms);
690+
lowerToScatter(input_tv, output_tv, params, comms);
676691
break;
677692
case CommunicationType::Gather:
678-
lowerToGather(input_tv, output_tv, backend, comms);
693+
lowerToGather(input_tv, output_tv, params, comms);
679694
break;
680695
case CommunicationType::Allgather:
681-
lowerToAllgather(input_tv, output_tv, backend, comms, my_device_idx);
696+
lowerToAllgather(input_tv, output_tv, params, comms);
682697
break;
683698
case CommunicationType::Broadcast:
684-
lowerToBroadcast(input_tv, output_tv, backend, host_loop_index, comms);
699+
lowerToBroadcast(input_tv, output_tv, params, comms);
685700
break;
686701
case CommunicationType::SendRecv:
687-
lowerToSendRecv(input_tv, output_tv, backend, comms);
702+
lowerToSendRecv(input_tv, output_tv, params, comms);
688703
break;
689704
case CommunicationType::ReduceScatter:
690-
lowerToReduceScatter(
691-
input_tv, output_tv, op_type(e), backend, comms, my_device_idx);
705+
lowerToReduceScatter(input_tv, output_tv, params, comms);
692706
break;
693707
case CommunicationType::Allreduce:
694-
lowerToAllreduce(
695-
input_tv, output_tv, op_type(e), backend, comms, my_device_idx);
708+
lowerToAllreduce(input_tv, output_tv, params, comms);
696709
break;
697710
case CommunicationType::Reduce:
698-
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
711+
lowerToReduce(input_tv, output_tv, params, comms);
699712
break;
700713
case CommunicationType::AllToAll:
701-
lowerToAllToAll(input_tv, output_tv, backend, comms);
714+
lowerToAllToAll(input_tv, output_tv, params, comms);
702715
break;
703716
case CommunicationType::CollectivePermute:
704-
lowerToCollectivePermute(
705-
input_tv, output_tv, backend, comms, host_loop_index, my_device_idx);
717+
lowerToCollectivePermute(input_tv, output_tv, params, comms);
706718
break;
707719
}
708720

0 commit comments

Comments (0)