Refactor CCL APIs to align with torch.distributed conventions #326
Conversation
This refactor reorders parameters and adds support for process groups.

API Changes:
- all_reduce: (out, in, op=SUM, group=None, async_op=False, config=None, workspace=None)
- reduce_scatter: (out, in, op=SUM, group=None, async_op=False, config=None)
- all_gather: (out, in, group=None, async_op=False, config=None)
- all_to_all: (out, in, group=None, async_op=False, config=None)

New Features:
- Add ReduceOp enum (SUM, PRODUCT, MIN, MAX, etc.) matching torch.distributed
- Add extract_group_info() helper to extract rank_start/rank_stride from ProcessGroup
- Support strided process groups (e.g., TP groups [0, 1, 2, 3] or DP groups [0, 4, 8, 12])
- The op parameter currently validates that only SUM is used (other ops to be added later)

Kernel Changes:
- All CCL kernels now accept rank_start and rank_stride constexpr parameters
- Kernel loops iterate using group-aware rank calculation
- Ring all-reduce computes next_rank on the host side for group support

Backward Compatibility:
- Existing code using keyword arguments (config=...) continues to work
- torch.distributed-compatible parameter ordering (group before config)
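A minimal usage sketch of the reordered signatures (hedged: the wrapper function, the `ctx` handle, and the `tp_group` argument are placeholders for illustration, not code from this PR):

```python
from iris.ccl import ReduceOp  # newly exported alongside the CCL ops

def allreduce_on_group(ctx, out, inp, tp_group=None):
    """Sketch only: `ctx` stands for an initialized Iris instance whose CCL
    methods follow the reordered, torch.distributed-style signatures."""
    # op and group now come before async_op/config, mirroring torch.distributed.
    ctx.all_reduce(out, inp, op=ReduceOp.SUM, group=tp_group)
    # Existing keyword-argument callers keep working, e.g. passing config by name:
    # ctx.all_reduce(out, inp, config=my_config)
    return out
```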
Pull request overview
This pull request refactors the CCL (Collective Communication Library) APIs to align with torch.distributed conventions by reordering parameters and adding support for process groups. However, the implementation contains several critical bugs that prevent process groups from working correctly.
Changes:
- Adds ReduceOp enum matching torch.distributed semantics
- Reorders API parameters to match torch.distributed: (out, in, op, group, async_op, config)
- Adds extract_group_info() helper to extract rank/stride information from ProcessGroup (see the sketch after this list)
- Updates all CCL kernels to accept rank_start and rank_stride parameters for group-aware rank calculation
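A hedged sketch of how a helper like extract_group_info() could derive rank_start/rank_stride from a strided group (the body below is an assumption for illustration; the PR's actual implementation may differ):

```python
import torch.distributed as dist

def extract_group_info_sketch(group=None):
    """Illustrative only: return (rank_start, rank_stride, group_size) for a
    group whose global ranks are evenly strided, e.g. a TP group [0, 1, 2, 3]
    or a DP group [0, 4, 8, 12]."""
    if group is None:
        # Default: the whole world, starting at rank 0 with stride 1.
        return 0, 1, dist.get_world_size()
    ranks = dist.get_process_group_ranks(group)
    rank_start = ranks[0]
    rank_stride = ranks[1] - ranks[0] if len(ranks) > 1 else 1
    # Only evenly strided groups are representable as (start, stride).
    assert all(r == rank_start + i * rank_stride for i, r in enumerate(ranks)), \
        "process group ranks must form an evenly strided sequence"
    return rank_start, rank_stride, len(ranks)
```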
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| iris/iris.py | Updated all CCL method signatures to add op and group parameters with reordered arguments |
| iris/experimental/iris_gluon.py | Updated CCL method signatures (missing group parameter in all_gather) |
| iris/ccl/__init__.py | Added ReduceOp to exports |
| iris/ccl/utils.py | Added ReduceOp enum and extract_group_info() helper function |
| iris/ccl/all_reduce.py | Updated kernels and function to support group parameters with rank_start/rank_stride |
| iris/ccl/reduce_scatter.py | Updated kernel and function to support group parameters |
| iris/ccl/all_gather.py | Updated kernel and function to support group parameters |
| iris/ccl/all_to_all.py | Updated Triton and Gluon kernels and function to support group parameters |
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
mawad-amd left a comment:
Can we find better names for cur_rank and cur_rank_global? They are really confusing throughout the code.
    rank_start: tl.constexpr,
    rank_stride: tl.constexpr,
This is a valid concern. Is an iris instance over the entire world or just the GPUs belonging to the group?
Let me think about that.
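For context on the rank_start/rank_stride constexpr parameters discussed above, a hedged Triton sketch of a group-aware rank loop (the kernel name, the flat per-rank buffer layout, and the reduce body are placeholders, not the PR's actual kernels):

```python
import triton
import triton.language as tl

@triton.jit
def group_aware_reduce_sketch(
    in_ptr,            # flat buffer with one slice per *global* rank (sketch assumption)
    out_ptr,           # output slice for this rank
    n_elements,
    rank_start: tl.constexpr,
    rank_stride: tl.constexpr,
    group_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # Group-aware iteration: member i of the group has global rank
    # rank_start + i * rank_stride (e.g. 0, 4, 8, 12 for a strided DP group).
    for i in range(group_size):
        remote_rank = rank_start + i * rank_stride
        # A real CCL kernel would translate remote_rank into a remote buffer
        # address; this sketch only reads per-rank slices of one local buffer.
        acc += tl.load(in_ptr + remote_rank * n_elements + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, acc, mask=mask)
```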
mawad-amd left a comment:
Looks good to me. Thanks!
Add an optional group parameter to barrier methods to support process-group-specific synchronization, aligning with the torch.distributed.barrier(group=None) API convention.
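For comparison, a sketch of the torch.distributed convention being mirrored (the ctx.barrier calls in the comments are assumptions about the Iris-side API):

```python
import torch.distributed as dist

def sync_group(group=None):
    """group=None waits on the whole world; otherwise only the ranks that
    belong to `group` synchronize at this point."""
    dist.barrier(group=group)

# Assumed Iris-side equivalent (placeholder `ctx` for an Iris instance):
#   ctx.barrier()                # world-wide, unchanged behavior
#   ctx.barrier(group=tp_group)  # synchronize only the group's ranks
```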