Add multi-precision training support to FSDP script #2662
base: main
Conversation
Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision.
- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>
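A sketch of what that match-based selection might look like, for readers of this thread. The recipe class names (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) come from the commit message above; the dtype choices, constructor defaults, and the configure_precision wrapper are assumptions, and fp16 maps to torch.float16 here per the review comment below rather than the bf16 the PR currently sets:

    import torch
    from transformer_engine.common.recipe import (
        DelayedScaling,
        MXFP8BlockScaling,
        NVFP4BlockScaling,
    )

    def configure_precision(precision: str):
        """Map a --precision value to (dtype, no_fp8, recipe) -- a sketch."""
        match precision:
            case "fp32":
                return torch.float32, True, None   # plain FP32, no FP8 kernels
            case "fp16":
                return torch.float16, True, None   # per the fp16 review comment below
            case "fp8":
                return torch.bfloat16, False, DelayedScaling()
            case "mxfp8":
                return torch.bfloat16, False, MXFP8BlockScaling()
            case "nvfp4":
                return torch.bfloat16, False, NVFP4BlockScaling()
            case _:
                raise ValueError(f"unsupported precision: {precision!r}")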
Greptile Overview

Greptile Summary

This PR updates the PyTorch FSDP example (fsdp.py) to support configurable training precision. The new selection logic integrates into the existing flow by setting dtype, no_fp8, and the TE recipe from the --precision argument.

Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant CLI as fsdp.py (CLI)
participant Train as train(opts)
participant TE as TransformerEngine
participant FSDP as Torch FSDP
CLI->>Train: parse_fsdp_args()
CLI->>Train: opts.precision, opts.dtype, opts.no_fp8
Train->>Train: match opts.precision
Train->>Train: choose dtype, no_fp8, recipe
Train->>TE: build TE modules (params_dtype=dtype)
Train->>FSDP: wrap with MixedPrecision(param_dtype=opts.dtype)
loop each iteration
Train->>Train: create input x (dtype=dtype)
Train->>TE: enter te.autocast(enabled=not no_fp8, recipe=recipe)
Train->>FSDP: forward/backward via FSDP-wrapped model
Train->>TE: exit te.autocast
end
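Rendered as a compact, runnable sketch (not the script itself): the sizes, the te.Linear stand-in for the real model, and the DelayedScaling default are assumptions, the FSDP wrap is elided, and backward is placed outside the autocast region as in typical TE usage. The diagram's te.autocast is written as te.fp8_autocast, the context manager long exposed by transformer_engine.pytorch:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    # Assumed sizes; the real script derives these from CLI options.
    hidden, batch, iters = 1024, 32, 3
    dtype, no_fp8, recipe = torch.bfloat16, False, DelayedScaling()

    # Stand-in for the FSDP-wrapped TE model (FSDP wrapping elided).
    model = te.Linear(hidden, hidden, params_dtype=dtype).cuda()

    for _ in range(iters):
        x = torch.rand(batch, hidden, dtype=dtype, device="cuda")
        with te.fp8_autocast(enabled=not no_fp8, fp8_recipe=recipe):
            loss = model(x).sum()
        loss.backward()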
1 file reviewed, 3 comments
| case "fp16": | ||
| dtype = torch.bfloat16 | ||
| no_fp8 = True | ||
| case "fp8": |
Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and diverges from the existing --dtype parsing, which supports torch.float16. A user running with --precision fp16 and expecting fp16 parameters/inputs will silently get bf16 instead.
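The corrected arm would presumably read (a sketch, keeping the surrounding match structure):

    case "fp16":
        dtype = torch.float16  # honor the requested precision, matching --dtype parsing
        no_fp8 = True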
    parser.add_argument(
        "--precision",
        type=precision,
        default="fp8",
        help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
    )
    return parser.parse_args()
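The type=precision callable refers to the validator the PR adds, whose body is not shown in this excerpt; a typical argparse type validator for these choices would look roughly like this (a sketch, names assumed):

    import argparse

    def precision(value: str) -> str:
        """Validate and normalize a --precision value (sketch)."""
        choices = ("fp32", "fp16", "fp8", "mxfp8", "nvfp4")
        v = value.lower()
        if v not in choices:
            raise argparse.ArgumentTypeError(
                f"precision must be one of {choices}, got {value!r}"
            )
        return v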
Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
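One way to realize the "derive defaults only when unspecified" option, as a sketch: default --dtype and --no-fp8 to None so train() can tell whether the user set them. Here torch_dtype stands in for the script's existing dtype parser, and configure_precision is the selection sketch from earlier in this thread; both names are assumptions:

    # Sentinel defaults distinguish "user passed the flag" from "not given".
    parser.add_argument("--dtype", type=torch_dtype, default=None)
    parser.add_argument("--no-fp8", action="store_const", const=True, default=None)

    opts = parser.parse_args()
    dtype, no_fp8, recipe = configure_precision(opts.precision)
    if opts.dtype is None:
        opts.dtype = dtype
    elif opts.dtype != dtype:
        parser.error(f"--dtype conflicts with --precision {opts.precision}")
    if opts.no_fp8 is None:
        opts.no_fp8 = no_fp8
    elif opts.no_fp8 and not no_fp8:
        parser.error(f"--no-fp8 conflicts with --precision {opts.precision}")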
Description
This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work.