-
Notifications
You must be signed in to change notification settings - Fork 631
[JAX] Fix FSDP when FSDP+EP is active #2649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] Fix FSDP when FSDP+EP is active #2649
Conversation
Greptile OverviewGreptile Summary
Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant UserCode as Model/TE JAX caller
participant JAX as JAX sharding propagation
participant Gemm as GemmPrimitive._parse_operand_output_specs
UserCode->>JAX: Invoke TE gemm with sharded operands
JAX->>Gemm: Provide arg_infos (PartitionSpec per operand dim)
Gemm->>Gemm: Split specs into contracting/non-contracting
Gemm->>Gemm: If reduce_spec is None
Gemm->>Gemm: For each rhs non-contracting dim
alt spec == fsdp_resource
Gemm->>Gemm: Set rhs spec to None (gather along FSDP)
else spec is tuple and contains fsdp_resource
Gemm->>Gemm: Set rhs spec to None (gather along FSDP tuple axis)
else other spec
Gemm->>Gemm: Keep rhs spec unchanged
end
Gemm-->>JAX: Return inferred operand/output PartitionSpecs
JAX-->>UserCode: Compile with corrected sharding (avoid unnecessary all-gathers)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, no comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
4ee4d73 to
08fda9c
Compare
|
/te-ci L1 jax |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, no comments
|
/te-ci L1 jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, no comments
Description
In some models when FSDP and EP are both active, the non-MoE blocks use
(('fsdp', 'expert'), None, None, None)with the EP GPU domain acting as FSDP. Our TE/JAX GEMM did not handle this correctly as it assumedfsdpwould be present in an axis alone, not as part of a tuple. This resulting in unnecessary AllGather's of the inputs blocking the critical path.This PR fixes the check and if the TE GEMM sees an inspect sharding with fsdp as part a tuple, it performs FSDP on the GPU domain of all axes specified in the tuple.
Type of change
Changes
('fsdp', 'expert')along the same axisChecklist: