diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py
index b469ef56b7..0da5b265d2 100644
--- a/examples/pytorch/fsdp/fsdp.py
+++ b/examples/pytorch/fsdp/fsdp.py
@@ -18,7 +18,12 @@
 )
 
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling
+from transformer_engine.common.recipe import (
+    Format,
+    DelayedScaling,
+    MXFP8BlockScaling,
+    NVFP4BlockScaling,
+)
 from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
 
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
@@ -68,6 +73,13 @@ def torch_dtype(d):
     return typemap[lowercase(d)]
 
 
+def precision(d):
+    typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
+    if lowercase(d) not in typemap:
+        raise TypeError(f"Unsupported precision: {d}")
+    return lowercase(d)
+
+
 te_layer_map = {
     "linear": te.Linear,
     "layernorm": te.LayerNorm,
@@ -91,7 +103,7 @@ def get_layer_args(opts):
     hidden_size = opts.num_heads * opts.head_dim
     layer_args = (hidden_size,)
     layer_kwargs = {
-        "params_dtype": opts.dtype,
+        # "params_dtype" is set in train() once the precision flags are resolved
         "device": "cuda" if opts.no_defer_init else "meta",
         "get_rng_state_tracker": get_cuda_rng_tracker,
     }
@@ -112,6 +124,38 @@ def get_layer_args(opts):
     return layer_args, layer_kwargs
 
 
+class StoreExplicitAction(argparse.Action):
+    """Custom action that tracks whether an argument was explicitly set."""
+
+    def __init__(self, option_strings, dest, type=None, **kwargs):
+        # Keep "type" out of the base Action so argparse does not convert first;
+        # conversion happens in __call__ below
+        super().__init__(option_strings, dest, **kwargs)
+        self.type_converter = type
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Apply the type converter if one was provided
+        if self.type_converter is not None:
+            try:
+                values = self.type_converter(values)
+            except (ValueError, TypeError) as e:
+                raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e
+
+        setattr(namespace, self.dest, values)
+        setattr(namespace, f"{self.dest}_explicitly_set", True)
+
+
+class StoreTrueExplicitAction(argparse.Action):
+    """Custom action for store_true that tracks whether the flag was explicitly set."""
+
+    def __init__(self, option_strings, dest, default=False, required=False, help=None):
+        super().__init__(
+            option_strings, dest, nargs=0, const=True, default=default, required=required, help=help
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, True)
+        setattr(namespace, f"{self.dest}_explicitly_set", True)
+
+
 def parse_fsdp_args():
     parser = argparse.ArgumentParser(
         description="Run Transformer Engine modules with the "
@@ -171,9 +215,16 @@ def parse_fsdp_args():
     )
     parser.add_argument(
         "--no-fp8",
-        action="store_true",
+        action=StoreTrueExplicitAction,
         default=False,
-        help="Disables the te.autocast() context.",
+        help=(
+            "Disables the te.autocast() context so the model trains in standard precision"
+            " (as specified by --dtype). Incompatible with the FP8-based --precision presets"
+            " (fp8/mxfp8/nvfp4): combining them raises a ValueError, since those presets"
+            " explicitly request FP8 training; use --precision fp32/fp16 for non-FP8"
+            " training instead. With --precision fp32/fp16 this flag is redundant but"
+            " harmless. Default: False (FP8 enabled)."
+        ),
     )
     parser.add_argument(
         "--no-defer-init",
@@ -189,7 +240,29 @@
         "--dtype",
         type=torch_dtype,
         default=torch.bfloat16,
-        help="Data type for input tensor and Transformer Engine module parameters.",
+        action=StoreExplicitAction,
+        help=(
+            "Data type for the input tensor and Transformer Engine module parameters."
+            " Supported values: fp32/float32, fp16/float16, bf16/bfloat16. When set"
+            " explicitly, this flag overrides the dtype implied by a --precision preset"
+            " (with a warning) while preserving the preset's FP8 recipe, e.g. '--precision"
+            " mxfp8 --dtype fp16' trains fp16 parameters with the MXFP8BlockScaling recipe."
+            " Default: bfloat16."
+        ),
+    )
+    parser.add_argument(
+        "--precision",
+        type=precision,
+        default=None,
+        help=(
+            "Precision preset that configures both the parameter dtype and the FP8 recipe."
+            " Supported values: fp32, fp16, fp8, mxfp8, nvfp4. fp32/fp16 select non-FP8"
+            " training in the given dtype; fp8, mxfp8, and nvfp4 select FP8 training with"
+            " the DelayedScaling, MXFP8BlockScaling, or NVFP4BlockScaling recipe"
+            " respectively (bf16 parameters). An explicit --dtype overrides the preset's"
+            " dtype (with a warning); combining an FP8 preset with --no-fp8 raises a"
+            " ValueError. Default: None (original --dtype/--no-fp8 behavior)."
+        ),
     )
 
     return parser.parse_args()
@@ -200,15 +273,122 @@ def dist_print(text, all_ranks=False, no_new_line=False):
     print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
 
 
+def get_precision_preset(precision_value):
+    """Get dtype, no_fp8, and recipe based on a precision preset.
+
+    Returns:
+        tuple: (dtype, no_fp8, recipe)
+    """
+    match precision_value:
+        case "fp32":
+            return torch.float32, True, None
+        case "fp16":
+            return torch.float16, True, None
+        case "fp8":
+            recipe = DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+            return torch.bfloat16, False, recipe
+        case "mxfp8":
+            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
+            return torch.bfloat16, False, recipe
+        case "nvfp4":
+            recipe = NVFP4BlockScaling()
+            return torch.bfloat16, False, recipe
+        case _:
+            # Fail loudly if validation is bypassed or a new preset is added without updating this function
+            raise ValueError(
+                f"Invalid precision preset: {precision_value}. "
+                "Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
+            )
+
+
+def get_recipe_for_precision(precision_value):
+    """Get the FP8 recipe for a precision preset (when FP8 is enabled).
+
+    Args:
+        precision_value: The precision preset string
+
+    Returns:
+        Recipe object for FP8 training
+    """
+    match precision_value:
+        case "mxfp8":
+            return MXFP8BlockScaling(fp8_format=Format.E4M3)
+        case "nvfp4":
+            return NVFP4BlockScaling()
+        case _:
+            # Default to DelayedScaling for fp8 or any other value
+            return DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+
+
 def train(opts):
+    # Check which flags were explicitly set
+    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
+    no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False)
+
     # Initialize torch.distributed global process group
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(LOCAL_RANK)
     dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
     torch.manual_seed(opts.seed)
 
+    # Determine the final configuration based on the precedence rules
+    if opts.precision is None:
+        # Case 1: backward compatibility - no precision preset specified,
+        # so the original --dtype and --no-fp8 flags control training
+        dtype = opts.dtype
+        no_fp8 = opts.no_fp8
+
+        # Set up the recipe if FP8 is enabled
+        if not no_fp8:
+            recipe = DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+        else:
+            recipe = None
+    else:
+        # Case 2: a precision preset was explicitly specified,
+        # so start from the preset values
+        preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
+
+        # Check for incompatible flag combinations:
+        # error if the user requests an FP8-based preset but also sets --no-fp8
+        if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8:
+            raise ValueError(
+                f"Cannot use --no-fp8 with --precision {opts.precision}. "
+                "These flags are incompatible. "
+                f"Either remove --no-fp8 to use {opts.precision} training, "
+                "or use --precision fp32/fp16 for non-FP8 training."
+            )
+
+        dtype = preset_dtype
+        no_fp8 = preset_no_fp8
+        recipe = preset_recipe
+
+        dist_print(f"Using precision preset: {opts.precision}")
+
+        # Apply an explicit dtype override with a warning
+        if dtype_explicitly_set:
+            dtype = opts.dtype
+            dist_print(
+                f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
+            )
+
+            # If FP8 is still enabled, keep the recipe selected by the preset
+            # (dtype only affects parameter storage, not the FP8 recipe)
+            if not no_fp8:
+                recipe = get_recipe_for_precision(opts.precision)
+
+    # Always log the final configuration being used
+    dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
+
     # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
     layer_args, layer_kwargs = get_layer_args(opts)
+    layer_kwargs["params_dtype"] = dtype
+
     if opts.num_layers > 1:
         te_layer_list = []
         for i in range(opts.num_layers):
@@ -239,7 +419,7 @@
         process_group=all_gpus,
         use_orig_params=True,
         mixed_precision=MixedPrecision(
-            param_dtype=opts.dtype,
+            param_dtype=dtype,
             reduce_dtype=torch.float32,
         ),
         auto_wrap_policy=fsdp_wrap_policy,
@@ -258,10 +438,6 @@
     dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
     dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")
 
-    # Fp8 setup for TE
-    fp8_format = Format.HYBRID
-    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
-
     # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
     optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)
 
@@ -281,11 +457,11 @@
         opts.seq_length,
         opts.batch_size,
         opts.num_heads * opts.head_dim,
-        dtype=opts.dtype,
+        dtype=dtype,
         device="cuda",
     )
     # autocast needs to be given the FSDP process group for amax reductions
-    with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
+    with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
         y = te_model(x)
         loss = y.sum()
     # calculate gradient and take training step outside the autocast context
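
For reference, a minimal standalone sketch of how the pieces of this patch fit together: the explicitly-set tracking actions and the --precision/--dtype/--no-fp8 precedence rules. Only the two action classes mirror the patch; the resolve() helper and the string dtypes are illustrative stand-ins, not the script's actual code.

```python
# Standalone sketch (not part of the patch). resolve() and the string dtypes
# below are illustrative stand-ins for the real precedence logic in train().
import argparse


class StoreExplicitAction(argparse.Action):
    """Store a value and record that the user passed the flag explicitly."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)


class StoreTrueExplicitAction(argparse.Action):
    """store_true variant that records explicit use of the flag."""

    def __init__(self, option_strings, dest, default=False, required=False, help=None):
        super().__init__(
            option_strings, dest, nargs=0, const=True, default=default, required=required, help=help
        )

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, True)
        setattr(namespace, f"{self.dest}_explicitly_set", True)


def resolve(opts):
    """Apply the patch's precedence rules; return (param_dtype, fp8_enabled)."""
    if opts.precision is None:
        # No preset: original behavior, --dtype and --no-fp8 control training
        return opts.dtype, not opts.no_fp8
    fp8 = opts.precision in ("fp8", "mxfp8", "nvfp4")
    if fp8 and getattr(opts, "no_fp8_explicitly_set", False):
        raise ValueError(f"--no-fp8 is incompatible with --precision {opts.precision}")
    dtype = opts.precision if opts.precision in ("fp32", "fp16") else "bf16"
    if getattr(opts, "dtype_explicitly_set", False):
        dtype = opts.dtype  # explicit --dtype wins; the FP8 recipe is kept
    return dtype, fp8


parser = argparse.ArgumentParser()
parser.add_argument("--dtype", default="bf16", action=StoreExplicitAction)
parser.add_argument("--no-fp8", action=StoreTrueExplicitAction)
parser.add_argument("--precision", default=None)

assert resolve(parser.parse_args([])) == ("bf16", True)
assert resolve(parser.parse_args(["--no-fp8"])) == ("bf16", False)
assert resolve(parser.parse_args(["--precision", "mxfp8", "--dtype", "fp16"])) == ("fp16", True)
```

The point of the sentinel attribute is that parser defaults stay untouched: the script can distinguish "--dtype bf16 passed by the user" from "bf16 by default", which is what makes the override warning and the --no-fp8 conflict check possible.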