
Encapsulate communication lowering variables into a struct #6037

Merged
Priya2698 merged 3 commits into main from pm/lowering_params
Mar 20, 2026

Conversation

@Priya2698
Collaborator

This allows me to easily pass in new variables without changing individual signatures for all lowerTo* methods.

@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR encapsulates the four communication-lowering variables (backend, my_device_idx, host_loop_index, reduction_op) into a new CommunicationLoweringParams struct, simplifying all lowerTo* function signatures from 4–6 individual parameters down to a single const CommunicationLoweringParams& params. The motivation is to let new variables be added later without changing any individual lowerTo* signature.

Key changes:

  • New POD struct CommunicationLoweringParams added in the anonymous namespace with sensible zero-default members.
  • All ten lowerTo* helpers updated to accept const CommunicationLoweringParams&.
  • The op_type lambda now returns std::optional<BinaryOpType> and is evaluated eagerly once for all communication types; reduction functions guard with NVF_ERROR(params.reduction_op.has_value(), ...) before calling valueOrError().
  • lowerToBroadcast's root parameter is now derived from params.host_loop_index inside the function body, preserving existing semantics.

The refactoring is clean and functionally equivalent. The one pre-existing concern (noted in the previous review thread) is that the op_type lambda now silently returns std::nullopt for unrecognised expression types, so the per-reduction error messages no longer include the offending expression e.

Confidence Score: 4/5

  • Safe to merge — pure refactoring with no functional changes to communication logic.
  • All ten lowerTo* call sites are updated consistently, designated-initializer struct construction is explicit (no field omitted), and the functional behaviour of every communication path is preserved. The only concern is the reduction in diagnostic quality for an error path (already flagged in a prior thread), which is not a correctness issue.
  • No files require special attention beyond the already-discussed diagnostic loss in the op_type lambda.

Important Files Changed

Filename Overview
csrc/host_ir/lower_to_communication.cpp Introduces CommunicationLoweringParams struct to bundle backend, my_device_idx, host_loop_index, and reduction_op; all lowerTo* function signatures simplified to take const CommunicationLoweringParams& params instead of individual parameters. The op_type lambda is now eager and returns std::optional<BinaryOpType>, losing the informative NVF_THROW diagnostic for unknown expression types (flagged in a prior review thread).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["convertSingleOpToCommunication(e, my_device_idx, host_loop_index, backend)"]
    B["op_type(e)\n→ std::optional&lt;BinaryOpType&gt;"]
    C["Build CommunicationLoweringParams\n{backend, my_device_idx,\n host_loop_index, reduction_op}"]
    D{communication_info->type}

    A --> B
    B --> C
    C --> D

    D --> E["lowerToScatter(params)"]
    D --> F["lowerToGather(params)"]
    D --> G["lowerToAllgather(params)\nuses my_device_idx"]
    D --> H["lowerToBroadcast(params)\nroot = host_loop_index"]
    D --> I["lowerToSendRecv(params)"]
    D --> J["lowerToReduce(params)\nchecks reduction_op"]
    D --> K["lowerToAllreduce(params)\nchecks reduction_op, my_device_idx"]
    D --> L["lowerToReduceScatter(params)\nchecks reduction_op, my_device_idx"]
    D --> M["lowerToAllToAll(params)"]
    D --> N["lowerToCollectivePermute(params)\nchecks host_loop_index != nullptr,\nuses my_device_idx"]

Last reviewed commit: "Merge branch 'main' ..."

Comment on lines +672 to 680
  auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
    if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
      return reduce->getReductionOpType();
    }
-   NVF_ERROR(e != nullptr);
-   if (e->isA<SqueezeOp>()) {
+   if (e != nullptr && e->isA<SqueezeOp>()) {
      return BinaryOpType::Add;
    }
-   NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);
+   return std::nullopt;
  };
Contributor

P2 Loss of diagnostic information in op_type lambda

The old lambda threw an explicit, informative error when the expression was neither a ReductionOp nor a SqueezeOp:

NVF_THROW("Expected a ReductionOp or a SqueezeOp, but got: ", e);

The new lambda silently returns std::nullopt for any such expression. The error is now deferred to the individual lowerToReduce/lowerToAllreduce/lowerToReduceScatter functions, but those error messages ("Reduce communication requires reduction_op in params", etc.) don't include the actual expression e that caused the problem.

As the comment in base.h above valueOrError itself notes: "If you prefer a better error message, have the caller check has_value() instead so it can provide more context." The current error messages do check has_value() but drop the expression context. Consider propagating e into those error messages, or restoring the NVF_THROW in the lambda for the case when e is non-null and of an unrecognized type:

auto op_type = [](Expr* e) -> std::optional<BinaryOpType> {
    if (auto* reduce = dynamic_cast<ReductionOp*>(e)) {
      return reduce->getReductionOpType();
    }
    if (e != nullptr && e->isA<SqueezeOp>()) {
      return BinaryOpType::Add;
    }
    // Optionally retain the informative message for unexpected non-null expressions:
    // NVF_ERROR(e == nullptr, "Expected a ReductionOp or a SqueezeOp, but got: ", e);
    return std::nullopt;
};

@Priya2698
Collaborator Author

!test

@Priya2698 Priya2698 requested a review from wujingyue March 18, 2026 18:56
return std::nullopt;
};

CommunicationLoweringParams params{
Collaborator

Looks good. You may want to consider the builder design pattern used by at::TensorOptions: https://docs.pytorch.org/cppdocs/notes/tensor_creation.html#configuring-properties-of-the-tensor. It's a more fluent way to build the parameter without having to mechanically create a struct and overwrite fields.

Collaborator Author

For now, I prefer the struct. All the fields need to be overwritten (even if I set the common values as defaults). If a case arises where some fields always keep their default values, I can consider switching the API.

@Priya2698 Priya2698 force-pushed the pm/lowering_params branch from a3cb641 to 2a66774 Compare March 18, 2026 22:19
@Priya2698
Collaborator Author

!test

@Priya2698 Priya2698 merged commit fe948b6 into main Mar 20, 2026
43 of 44 checks passed
@Priya2698 Priya2698 deleted the pm/lowering_params branch March 20, 2026 22:48