Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 72 additions & 60 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ namespace nvfuser {

namespace {

// Bundles the per-call context shared by the lowerTo* communication-lowering
// helpers, so each helper takes a single params struct instead of a long
// positional argument list.
struct CommunicationLoweringParams {
  // Backend the created Communication nodes should target; forwarded as the
  // last argument of every IrBuilder::create<Communication>(...) call.
  CommunicatorBackend backend{};
  // Index of the device this lowering runs for; used to slice the device
  // mesh along DIDx (e.g. in Allgather/Allreduce/ReduceScatter team setup).
  DeviceIdxType my_device_idx{};
  // Index of the enclosing host loop. Required (non-null) by
  // CollectivePermute; also used as the Broadcast root. May be nullptr for
  // communication types that don't use it.
  Val* host_loop_index{};
  // Reduction operator for Reduce/Allreduce/ReduceScatter; std::nullopt for
  // non-reduction communications (the reduce lowerings NVF_ERROR on nullopt).
  std::optional<BinaryOpType> reduction_op;
};

// TODO: handle `c10d::RedOpType::reduceOp::AVG` and
// `c10d::RedOpType::reduceOp::PREMUL_SUM`
c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) {
Expand Down Expand Up @@ -55,7 +62,7 @@ c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) {
void lowerToScatter(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
Expand Down Expand Up @@ -83,7 +90,7 @@ void lowerToScatter(
team,
getRelativeIndex(team, root),
c10d::ReduceOp::RedOpType::UNUSED,
backend));
params.backend));
}

// Adds zero or multiple Gather communications to the vector 'comms'
Expand All @@ -93,7 +100,7 @@ void lowerToScatter(
void lowerToGather(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
// we create as many 'Gathers' as there are devices in the receiver mesh
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
Expand All @@ -114,27 +121,26 @@ void lowerToGather(
team,
getRelativeIndex(team, root),
c10d::ReduceOp::RedOpType::UNUSED,
backend));
params.backend));
}
}

// Adds an Allgather communication to the vector 'comms'.
//
// The participating team is the DIDx slice of the input's device mesh that
// contains params.my_device_idx; Allgather is rootless, so root is -1.
void lowerToAllgather(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicationLoweringParams& params,
    std::vector<Expr*>& comms) {
  const DeviceMesh& mesh = input_tv->getDeviceMesh();
  // Only the devices on this device's DIDx slice of the mesh participate.
  Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::Allgather,
      output_tv,
      input_tv,
      team,
      /*root=*/-1,
      c10d::ReduceOp::RedOpType::UNUSED,
      params.backend));
}

// Either of the following cases is happening:
Expand All @@ -144,13 +150,13 @@ void lowerToAllgather(
void lowerToBroadcast(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
Val* root,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();

Team team = receiver_mesh.vector();
Val* root = params.host_loop_index;

if (sender_mesh == receiver_mesh) {
NVF_ERROR(
Expand All @@ -175,7 +181,7 @@ void lowerToBroadcast(
team,
root,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
params.backend));
}

// Adds several SendRecv communications to the vector 'comms'
Expand All @@ -185,7 +191,7 @@ void lowerToBroadcast(
void lowerToSendRecv(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
Expand Down Expand Up @@ -214,16 +220,18 @@ void lowerToSendRecv(
team,
/*root=*/getRelativeIndex(team, sender),
c10d::ReduceOp::RedOpType::UNUSED,
backend));
params.backend));
}
}

void lowerToReduce(
TensorView* input_tv,
TensorView* output_tv,
BinaryOpType op_type,
const CommunicatorBackend backend,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
NVF_ERROR(
params.reduction_op.has_value(),
"Reduce communication requires reduction_op in params");
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
NVF_ERROR_EQ(
Expand All @@ -236,7 +244,8 @@ void lowerToReduce(
1,
"Reduce only supported a 1D mesh. Given ",
receiver_mesh);
const auto reduce_op_type = getC10dReduceOpType(op_type);
const auto reduce_op_type =
getC10dReduceOpType(valueOrError(params.reduction_op));
// we create as many Reduces as there are devices in the receiver mesh
for (auto root : receiver_mesh.vector()) {
Team team = sender_mesh.vector();
Expand All @@ -250,36 +259,38 @@ void lowerToReduce(
team,
getRelativeIndex(team, root),
reduce_op_type,
backend));
params.backend));
}
}

// Adds an Allreduce communication to the vector 'comms'.
//
// Requires params.reduction_op to be set (checked via NVF_ERROR). The
// participating team is the DIDx slice of the input's device mesh that
// contains params.my_device_idx; Allreduce is rootless, so root is -1.
void lowerToAllreduce(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicationLoweringParams& params,
    std::vector<Expr*>& comms) {
  NVF_ERROR(
      params.reduction_op.has_value(),
      "Allreduce communication requires reduction_op in params");
  const DeviceMesh& mesh = input_tv->getDeviceMesh();
  // Only the devices on this device's DIDx slice of the mesh participate.
  Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::Allreduce,
      output_tv,
      input_tv,
      team,
      /*root=*/-1,
      getC10dReduceOpType(valueOrError(params.reduction_op)),
      params.backend));
}

void lowerToReduceScatter(
TensorView* input_tv,
TensorView* output_tv,
BinaryOpType op_type,
const CommunicatorBackend backend,
std::vector<Expr*>& comms,
DeviceIdxType my_device_idx) {
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
NVF_ERROR(
params.reduction_op.has_value(),
"ReduceScatter communication requires reduction_op in params");
NVF_ERROR_EQ(
input_tv->getDeviceMesh(),
output_tv->getDeviceMesh(),
Expand All @@ -288,22 +299,22 @@ void lowerToReduceScatter(
"Insert a Set operation before or after the reduction to reshard to "
"another device mesh");
const DeviceMesh& mesh = input_tv->getDeviceMesh();
Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx);
Team team = mesh.getSlice(params.my_device_idx, ParallelType::DIDx);

comms.push_back(IrBuilder::create<Communication>(
CommunicationType::ReduceScatter,
output_tv,
input_tv,
/*team=*/team,
/*root=*/-1,
getC10dReduceOpType(op_type),
backend));
getC10dReduceOpType(valueOrError(params.reduction_op)),
params.backend));
}

void lowerToAllToAll(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
Expand Down Expand Up @@ -331,16 +342,14 @@ void lowerToAllToAll(
sender_mesh.vector(),
/*root=*/-1,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
params.backend));
}

void lowerToCollectivePermute(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms,
Val* host_loop_index,
DeviceIdxType my_device_idx) {
const CommunicationLoweringParams& params,
std::vector<Expr*>& comms) {
NVF_ERROR_EQ(
input_tv->getDeviceMesh(),
output_tv->getDeviceMesh(),
Expand All @@ -350,7 +359,7 @@ void lowerToCollectivePermute(
output_tv->getDeviceMesh());

NVF_ERROR(
host_loop_index != nullptr,
params.host_loop_index != nullptr,
"Host loop index must be provided for CollectivePermute.");

IterDomain* stream_id =
Expand All @@ -362,10 +371,13 @@ void lowerToCollectivePermute(
ParallelType pt = swizzle->parallelType();

const auto& [recv_peer, send_peer] = dispatchSwizzle1D(
host_loop_index, my_device_idx, pt, input_tv->getDeviceMesh());
params.host_loop_index,
params.my_device_idx,
pt,
input_tv->getDeviceMesh());
Team team = input_tv->getDeviceMesh().vector();
comms.push_back(IrBuilder::create<CollectivePermute>(
output_tv, input_tv, team, send_peer, recv_peer, backend));
output_tv, input_tv, team, send_peer, recv_peer, params.backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
Expand Down Expand Up @@ -657,52 +669,52 @@ std::vector<Expr*> convertSingleOpToCommunication(
"Expected communication info for resharding expr: ",
e);

auto op_type = [](Expr* e) -> BinaryOpType {
auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
return reduce->getReductionOpType();
}

NVF_ERROR(e != nullptr);
if (e->isA<SqueezeOp>()) {
if (e != nullptr && e->isA<SqueezeOp>()) {
return BinaryOpType::Add;
}

NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);
return std::nullopt;
};
Comment on lines +672 to 680
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Loss of diagnostic information in op_type lambda

The old lambda threw an explicit, informative error when the expression was neither a ReductionOp nor a SqueezeOp:

NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);

The new lambda silently returns std::nullopt for any such expression. The error is now deferred to the individual lowerToReduce/lowerToAllreduce/lowerToReduceScatter functions, but those error messages ("Reduce communication requires reduction_op in params", etc.) don't include the actual expression e that caused the problem.

As the comment in base.h above valueOrError itself notes: "If you prefer a better error message, have the caller check has_value() instead so it can provide more context." The current error messages do check has_value() but drop the expression context. Consider propagating e into those error messages, or restoring the NVF_THROW in the lambda for the case when e is non-null and of an unrecognized type:

auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
    if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
      return reduce->getReductionOpType();
    }
    if (e != nullptr && e->isA<SqueezeOp>()) {
      return BinaryOpType::Add;
    }
    // Optionally retain the informative message for unexpected non-null expressions:
    // NVF_ERROR(e == nullptr, "Expected a ReductionOp or a SqueezeOp, but got: ", e);
    return std::nullopt;
};


CommunicationLoweringParams params{
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. You may want to consider the builder design pattern used by at::TensorOptions: https://docs.pytorch.org/cppdocs/notes/tensor_creation.html#configuring-properties-of-the-tensor. It's a more fluent way to build the parameter without having to mechanically create a struct and overwrite fields.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I prefer the struct. All the fields need to be overwritten (even if I set the common values as defaults). If a case arises where some fields always take these default values, I can consider switching the API.

.backend = backend,
.my_device_idx = my_device_idx,
.host_loop_index = host_loop_index,
.reduction_op = op_type(e)};

switch (communication_info->type) {
case CommunicationType::Scatter:
lowerToScatter(input_tv, output_tv, backend, comms);
lowerToScatter(input_tv, output_tv, params, comms);
break;
case CommunicationType::Gather:
lowerToGather(input_tv, output_tv, backend, comms);
lowerToGather(input_tv, output_tv, params, comms);
break;
case CommunicationType::Allgather:
lowerToAllgather(input_tv, output_tv, backend, comms, my_device_idx);
lowerToAllgather(input_tv, output_tv, params, comms);
break;
case CommunicationType::Broadcast:
lowerToBroadcast(input_tv, output_tv, backend, host_loop_index, comms);
lowerToBroadcast(input_tv, output_tv, params, comms);
break;
case CommunicationType::SendRecv:
lowerToSendRecv(input_tv, output_tv, backend, comms);
lowerToSendRecv(input_tv, output_tv, params, comms);
break;
case CommunicationType::ReduceScatter:
lowerToReduceScatter(
input_tv, output_tv, op_type(e), backend, comms, my_device_idx);
lowerToReduceScatter(input_tv, output_tv, params, comms);
break;
case CommunicationType::Allreduce:
lowerToAllreduce(
input_tv, output_tv, op_type(e), backend, comms, my_device_idx);
lowerToAllreduce(input_tv, output_tv, params, comms);
break;
case CommunicationType::Reduce:
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
lowerToReduce(input_tv, output_tv, params, comms);
break;
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
lowerToAllToAll(input_tv, output_tv, params, comms);
break;
case CommunicationType::CollectivePermute:
lowerToCollectivePermute(
input_tv, output_tv, backend, comms, host_loop_index, my_device_idx);
lowerToCollectivePermute(input_tv, output_tv, params, comms);
break;
}

Expand Down
Loading