Commits (25)
c960e6d  Add precision parameter support for multiple training formats (Feb 9, 2026)
5949884  Merge branch 'NVIDIA:main' into extend-precision (aagallo, Feb 9, 2026)
cd97843  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 9, 2026)
1341abb  Merge branch 'NVIDIA:main' into main (aagallo, Feb 10, 2026)
40919d8  Fix FP16 dtype mapping and implement CLI flag precedence (aagallo, Feb 10, 2026)
5c1db12  Add logging and documentation for precision configuration (aagallo, Feb 10, 2026)
e4846c8  Initialize recipe variable in all precision cases (aagallo, Feb 10, 2026)
26aee2f  Fix dtype flag detection to support explicit override behavior (aagallo, Feb 10, 2026)
295a106  Merge remote-tracking branch 'origin/main' into extend-precision (aagallo, Feb 10, 2026)
c9524e3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 10, 2026)
a506d40  Replace sys.argv parsing with custom action and fix default case (aagallo, Feb 10, 2026)
b100f2c  Merge branch 'extend-precision' of https://github.com/aagallo/Transfo… (aagallo, Feb 10, 2026)
ec31f2a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 10, 2026)
07a87a7  Fix params_dtype to use computed dtype from precision logic (aagallo, Feb 10, 2026)
748bb39  Merge branch 'extend-precision' of https://github.com/aagallo/Transfo… (aagallo, Feb 10, 2026)
c6fb3a5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 10, 2026)
da9f82b  Merge branch 'main' into extend-precision (aagallo, Feb 10, 2026)
2c5473e  Fix type conversion in StoreExplicitAction for --dtype argument (aagallo, Feb 10, 2026)
a9e664c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 10, 2026)
196df3d  Fix precision preset recipe selection and add incompatibility validation (aagallo, Feb 11, 2026)
368820b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 11, 2026)
9e2f34b  Merge branch 'main' into extend-precision (aagallo, Feb 11, 2026)
76dcb94  Fix unreachable default case and redundant recipe recreation (aagallo, Feb 11, 2026)
e22c2f2  Add explicit error handling for invalid precision presets (aagallo, Feb 11, 2026)
9d637d5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 11, 2026)
examples/pytorch/fsdp/fsdp.py: 211 changes (199 additions, 12 deletions)
@@ -18,7 +18,12 @@
)

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
)
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
@@ -68,6 +73,13 @@ def torch_dtype(d):
return typemap[lowercase(d)]


def precision(d):
typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
if lowercase(d) not in typemap:
raise TypeError
return lowercase(d)


te_layer_map = {
"linear": te.Linear,
"layernorm": te.LayerNorm,
@@ -91,7 +103,7 @@ def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size,)
layer_kwargs = {
"params_dtype": opts.dtype,
# "params_dtype": opts.dtype,
"device": "cuda" if opts.no_defer_init else "meta",
"get_rng_state_tracker": get_cuda_rng_tracker,
}
@@ -112,6 +124,38 @@ def get_layer_args(opts):
return layer_args, layer_kwargs


class StoreExplicitAction(argparse.Action):
"""Custom action that tracks whether an argument was explicitly set."""

def __init__(self, option_strings, dest, type=None, **kwargs):
super().__init__(option_strings, dest, **kwargs)
self.type_converter = type # Store the type converter

def __call__(self, parser, namespace, values, option_string=None):
# Apply the type converter if one was provided
if self.type_converter is not None:
try:
values = self.type_converter(values)
except (ValueError, TypeError) as e:
raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e

setattr(namespace, self.dest, values)
setattr(namespace, f"{self.dest}_explicitly_set", True)


class StoreTrueExplicitAction(argparse.Action):
"""Custom action for store_true that tracks whether flag was explicitly set."""

def __init__(self, option_strings, dest, default=False, required=False, help=None):
super().__init__(
option_strings, dest, nargs=0, const=True, default=default, required=required, help=help
)

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, True)
setattr(namespace, f"{self.dest}_explicitly_set", True)


def parse_fsdp_args():
parser = argparse.ArgumentParser(
description="Run Transformer Engine modules with the "
@@ -171,9 +215,21 @@ def parse_fsdp_args():
)
parser.add_argument(
"--no-fp8",
Contributor comment:
Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250

Suggested change
"--no-fp8",
action=StoreExplicitAction,

action="store_true",
action=StoreTrueExplicitAction,
default=False,
help="Disables the te.autocast() context.",
help=(
"Disables the te.autocast() context. When set, FP8 training is disabled and the model"
" trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is"
" incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:"
" Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but"
" harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR"
" (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so"
" disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead"
" for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)."
" '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is"
" the correct way to request non-FP8 training. Default: False (FP8 enabled based on"
" configuration)."
),
)
parser.add_argument(
"--no-defer-init",
@@ -189,7 +245,35 @@ def parse_fsdp_args():
"--dtype",
type=torch_dtype,
default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.",
action=StoreExplicitAction,
help=(
"Data type for input tensor and Transformer Engine module parameters. Supported values:"
" fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag"
" overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:"
" Controls parameter dtype directly - With --precision: Overrides preset's dtype but"
" preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters"
" (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with"
" MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype."
" Default: bfloat16."
),
)
parser.add_argument(
"--precision",
type=precision,
default=None,
help=(
"Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4."
" This is a convenience flag that configures both dtype and FP8 settings automatically."
" - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with"
" DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling"
" recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16"
" parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype"
" from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an"
" error is raised (incompatible) - If this flag is not set, original behavior is used"
" (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8"
" FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16"
" parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)."
),
)
return parser.parse_args()
Comment on lines 260 to 278 (Contributor):
Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
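
A minimal sketch of the precedence the reviewer is asking for, reusing this PR's get_precision_preset() and the *_explicitly_set attributes recorded by its custom argparse actions; the function name and structure here are illustrative, not the PR's actual code:

def resolve_precision(opts):
    # Start from the plain flags (original behavior when --precision is absent).
    dtype, no_fp8 = opts.dtype, opts.no_fp8
    if opts.precision is not None:
        preset_dtype, preset_no_fp8, _ = get_precision_preset(opts.precision)
        # Refuse contradictory requests: an FP8 preset combined with an explicit --no-fp8.
        if getattr(opts, "no_fp8_explicitly_set", False) and opts.no_fp8 and not preset_no_fp8:
            raise ValueError(f"--no-fp8 is incompatible with --precision {opts.precision}")
        # The preset supplies defaults; an explicit --dtype overrides the preset's dtype.
        dtype = opts.dtype if getattr(opts, "dtype_explicitly_set", False) else preset_dtype
        no_fp8 = preset_no_fp8
    return dtype, no_fp8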


@@ -200,15 +284,122 @@ def dist_print(text, all_ranks=False, no_new_line=False):
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)


def get_precision_preset(precision_value):
"""Get dtype, no_fp8, and recipe based on precision preset.

Returns:
tuple: (dtype, no_fp8, recipe)
"""
match precision_value:
case "fp32":
return torch.float32, True, None
case "fp16":
return torch.float16, True, None
case "fp8":
recipe = DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
return torch.bfloat16, False, recipe
case "mxfp8":
recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
return torch.bfloat16, False, recipe
case "nvfp4":
recipe = NVFP4BlockScaling()
return torch.bfloat16, False, recipe
case _:
# Fail loudly if validation is bypassed or new preset added without updating this function
raise ValueError(
f"Invalid precision preset: {precision_value}. "
"Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
)


def get_recipe_for_precision(precision_value):
"""Get FP8 recipe based on precision preset (when FP8 is enabled).

Args:
precision_value: The precision preset string

Returns:
Recipe object for FP8 training
"""
match precision_value:
case "mxfp8":
return MXFP8BlockScaling(fp8_format=Format.E4M3)
case "nvfp4":
return NVFP4BlockScaling()
case _:
# Default to DelayedScaling for fp8 or any other value
return DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)


def train(opts):
# Check which flags were explicitly set
dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False)

# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
torch.cuda.set_device(LOCAL_RANK)
dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
torch.manual_seed(opts.seed)

# Determine final configuration based on precedence rules
if opts.precision is None:
# Case 1: Backward compatibility - no precision preset specified
# Use original behavior with dtype and no_fp8 flags
dtype = opts.dtype
no_fp8 = opts.no_fp8

# Set up recipe if FP8 is enabled
if not no_fp8:
recipe = DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
else:
recipe = None
else:
# Case 2: Precision preset was explicitly specified
# Start with precision preset values
preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)

# Check for incompatible flag combinations
# Error if user requests FP8-based precision but also sets --no-fp8
if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8:
raise ValueError(
f"Cannot use --no-fp8 with --precision {opts.precision}. "
"These flags are incompatible. "
f"Either remove --no-fp8 to use {opts.precision} training, "
"or use --precision fp32/fp16 for non-FP8 training."
)

dtype = preset_dtype
no_fp8 = preset_no_fp8
recipe = preset_recipe

dist_print(f"Using precision preset: {opts.precision}")

# Apply explicit dtype override with warning
if dtype_explicitly_set:
dtype = opts.dtype
dist_print(
f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
)

# If FP8 is still enabled, keep recipe based on precision
# (dtype only affects parameter storage, not FP8 recipe)
if not no_fp8:
recipe = get_recipe_for_precision(opts.precision)
Comment on lines +386 to +394 (Contributor):
Redundant recipe recreation

When dtype_explicitly_set is true and FP8 remains enabled, the code overwrites recipe with get_recipe_for_precision(opts.precision), even though recipe was already set from the selected preset earlier. This creates a second recipe instance (and can diverge if preset recipe configuration changes). If the intent is “dtype override shouldn’t affect recipe”, you can keep the existing recipe rather than re-instantiating it.
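
A minimal sketch of the reviewer's point, assuming the surrounding variables from train(); the preset's recipe object is kept rather than re-instantiated:

# Sketch only: a dtype override leaves the preset's recipe untouched.
dtype = preset_dtype
no_fp8 = preset_no_fp8
recipe = preset_recipe  # single instance from get_precision_preset()

if getattr(opts, "dtype_explicitly_set", False):
    # Only the parameter storage dtype changes; the FP8 recipe already reflects the preset.
    dtype = opts.dtype
    dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting")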


# Always log the final configuration being used
dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")

# Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
layer_args, layer_kwargs = get_layer_args(opts)
layer_kwargs["params_dtype"] = dtype

if opts.num_layers > 1:
te_layer_list = []
for i in range(opts.num_layers):
@@ -239,7 +430,7 @@ def train(opts):
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=opts.dtype,
param_dtype=dtype,
reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy,
@@ -258,10 +449,6 @@ def train(opts):
dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")

# Fp8 setup for TE
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

# Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

@@ -281,11 +468,11 @@ def train(opts):
opts.seq_length,
opts.batch_size,
opts.num_heads * opts.head_dim,
dtype=opts.dtype,
dtype=dtype,
device="cuda",
)
# autocast needs to be given the FSDP process group for amax reductions
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the autocast context