aagallo commented Feb 9, 2026

Description

This PR adds precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation automatically configures the appropriate dtype and a format-specific recipe for each precision type.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added precision() type validator function supporting fp32, fp16, fp8, mxfp8, and nvfp4 formats
  • Added --precision command-line argument to parse_fsdp_args() with default value "fp8"
  • Implemented match statement in train() function to configure precision-based training parameters (a sketch of this selection logic follows the list)
  • Configured format-specific recipes for each precision type:
    • FP32/FP16: Uses standard PyTorch dtypes with FP8 disabled
    • FP8: Uses DelayedScaling recipe with HYBRID format
    • MXFP8: Uses MXFP8BlockScaling recipe with E4M3 format
    • NVFP4: Uses NVFP4BlockScaling recipe with bfloat16 dtype
  • Set appropriate no_fp8 flags based on precision selection
  • Updated layer_kwargs["params_dtype"] to use precision-determined dtype
  • Imported required recipe classes: MXFP8BlockScaling and NVFP4BlockScaling
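
For reference, a minimal sketch of this selection logic. The recipe constructor arguments shown here are assumptions (the exact arguments in examples/pytorch/fsdp/fsdp.py may differ), and this sketch maps fp16 to torch.float16, whereas the submitted code currently maps it to bfloat16 as noted in the review below.

import torch
from transformer_engine.common.recipe import (
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    NVFP4BlockScaling,
)


def precision(value: str) -> str:
    """argparse type validator for --precision."""
    value = value.lower()
    if value not in ("fp32", "fp16", "fp8", "mxfp8", "nvfp4"):
        raise ValueError(f"Unsupported precision: {value}")
    return value


def configure_precision(selected: str):
    """Map a precision string to (dtype, no_fp8, recipe)."""
    recipe = None
    match selected:
        case "fp32":
            dtype, no_fp8 = torch.float32, True
        case "fp16":
            dtype, no_fp8 = torch.float16, True
        case "fp8":
            dtype, no_fp8 = torch.bfloat16, False
            recipe = DelayedScaling(fp8_format=Format.HYBRID)
        case "mxfp8":
            dtype, no_fp8 = torch.bfloat16, False
            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
        case "nvfp4":
            dtype, no_fp8 = torch.bfloat16, False
            recipe = NVFP4BlockScaling()
        case _:
            raise ValueError(f"Unsupported precision: {selected}")
    return dtype, no_fp8, recipe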

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on this work.

aagallo and others added 3 commits February 9, 2026 16:28
Enable configurable precision training with support for FP32, FP16, FP8,
MXFP8, and NVFP4 formats. Added precision argument parser and match statement
to configure appropriate dtype and recipe based on selected precision.

- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>

greptile-apps bot commented Feb 9, 2026

Greptile Overview

Greptile Summary

This PR updates the PyTorch FSDP example (examples/pytorch/fsdp/fsdp.py) to add a --precision CLI argument and route training configuration through it. The script now selects a TE recipe (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling), an input/parameter dtype, and whether to enable te.autocast based on the requested precision.

The new selection logic integrates into the existing flow by setting layer_kwargs["params_dtype"], using the selected dtype for the synthetic input tensor, and passing the selected recipe into te.autocast during the forward pass.
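
A rough sketch of that flow is shown below. The tensor shape and device are assumptions, and the autocast call follows the signature implied by the sequence diagram further down; older TransformerEngine releases expose the same functionality as te.fp8_autocast.

import torch
import transformer_engine.pytorch as te


def training_step(model, dtype, no_fp8, recipe,
                  seq_len=128, batch_size=4, hidden=768):
    # Synthetic input is created in the precision-selected dtype.
    x = torch.randn(seq_len, batch_size, hidden, dtype=dtype, device="cuda")
    # Recipe and enablement are both driven by the --precision switch.
    with te.autocast(enabled=not no_fp8, recipe=recipe):
        out = model(x)
    out.float().sum().backward()
    return out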

Confidence Score: 2/5

  • This PR has functional issues that can lead to incorrect dtype behavior under FSDP and surprising CLI semantics.
  • The --precision switch overrides model/input dtypes and FP8 enablement, but FSDP MixedPrecision is still driven by opts.dtype, creating a deterministic mismatch for some precision selections (e.g., fp32). Additionally, fp16 currently maps to bf16, and --precision can silently override explicit --no-fp8/--dtype flags, making runs behave differently than requested.
  • examples/pytorch/fsdp/fsdp.py

Important Files Changed

Filename: examples/pytorch/fsdp/fsdp.py
Overview: Adds --precision and selects the TE autocast recipe/dtype based on it, but introduces mismatches between the selected dtype and FSDP MixedPrecision(param_dtype=opts.dtype) and sets fp16 to bf16; CLI precedence between --precision and --dtype/--no-fp8 is also inconsistent.

Sequence Diagram

sequenceDiagram
    participant CLI as fsdp.py (CLI)
    participant Train as train(opts)
    participant TE as TransformerEngine
    participant FSDP as Torch FSDP

    CLI->>Train: parse_fsdp_args()
    CLI->>Train: opts.precision, opts.dtype, opts.no_fp8
    Train->>Train: match opts.precision
    Train->>Train: choose dtype, no_fp8, recipe
    Train->>TE: build TE modules (params_dtype=dtype)
    Train->>FSDP: wrap with MixedPrecision(param_dtype=opts.dtype)

    loop each iteration
        Train->>Train: create input x (dtype=dtype)
        Train->>TE: enter te.autocast(enabled=not no_fp8, recipe=recipe)
        Train->>FSDP: forward/backward via FSDP-wrapped model
        Train->>TE: exit te.autocast
    end


greptile-apps bot left a comment

1 file reviewed, 3 comments


Comment on lines +243 to +246
case "fp16":
dtype = torch.bfloat16
no_fp8 = True
case "fp8":

Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.
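
A possible correction, sketched as a standalone lookup (a hypothetical helper, not the PR's code):

import torch

# Hypothetical mapping from --precision values to the intended dtype.
PRECISION_DTYPES = {
    "fp32": torch.float32,
    "fp16": torch.float16,   # float16, not bfloat16 as in the current branch
    "fp8": torch.bfloat16,   # FP8/MXFP8/NVFP4 keep bf16 params and inputs
    "mxfp8": torch.bfloat16,
    "nvfp4": torch.bfloat16,
}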

Comment on lines +206 to 212
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
    help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
)
return parser.parse_args()

Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
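
One way to enforce that precedence, sketched under the assumption that the script exposes --precision, --dtype, and --no-fp8 as described (the flag handling here is illustrative, not the PR's code): leave --dtype/--no-fp8 unset by default and derive them from --precision only when the user did not pass them explicitly.

import argparse

import torch

_DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


def resolve_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--precision", default="fp8",
                        choices=["fp32", "fp16", "fp8", "mxfp8", "nvfp4"])
    parser.add_argument("--dtype", default=None, choices=list(_DTYPES))
    parser.add_argument("--no-fp8", action="store_true", default=None)
    opts = parser.parse_args(argv)

    # Derive dtype/no_fp8 from --precision only when they were not given.
    if opts.dtype is None:
        opts.dtype = {"fp32": "fp32", "fp16": "fp16"}.get(opts.precision, "bf16")
    if opts.no_fp8 is None:
        opts.no_fp8 = opts.precision in ("fp32", "fp16")
    elif opts.no_fp8 and opts.precision in ("fp8", "mxfp8", "nvfp4"):
        parser.error(f"--no-fp8 conflicts with --precision {opts.precision}")
    opts.dtype = _DTYPES[opts.dtype]
    return opts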


greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP mixed_precision mismatch
layer_kwargs["params_dtype"] and the input tensor dtype are overridden by --precision, but FSDP is still configured with mixed_precision=MixedPrecision(param_dtype=opts.dtype, ...) (and opts.dtype no longer matches the model param dtype when --precision is used). This will cause inconsistent param casting/communication behavior under FSDP for e.g. --precision fp32 (params are fp32 but FSDP thinks they’re bf16) and --precision fp16 (currently sets dtype=torch.bfloat16). FSDP param_dtype should be driven by the same dtype selected in the precision switch, or the precision switch should not override param dtype when FSDP mixed precision is enabled.
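
For example, FSDP's policy could be built from the same dtype the precision switch selects (a sketch; the helper name is illustrative, MixedPrecision is the standard torch.distributed.fsdp class):

import torch
from torch.distributed.fsdp import MixedPrecision


def mixed_precision_for(dtype: torch.dtype) -> MixedPrecision:
    """Keep FSDP's casting/communication dtypes in sync with --precision."""
    return MixedPrecision(
        param_dtype=dtype,   # same dtype as params_dtype and the input tensor
        reduce_dtype=dtype,  # gradient reduction dtype
        buffer_dtype=dtype,
    )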
