From c960e6dc515b0d663ea904c7b6478b86f9f3d5d8 Mon Sep 17 00:00:00 2001 From: aagallo Date: Mon, 9 Feb 2026 10:03:06 -0500 Subject: [PATCH 01/18] Add precision parameter support for multiple training formats Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision. - Add precision() type validator function - Implement precision-based configuration in train() - Support FP32, FP16, FP8, MXFP8, and NVFP4 formats - Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) - Set appropriate no_fp8 flags based on precision selection Signed-off-by: aagallo --- examples/pytorch/fsdp/fsdp.py | 65 +++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 7 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b469ef56b7..2b8dc3a436 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -18,7 +18,7 @@ ) import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -68,6 +68,19 @@ def torch_dtype(d): return typemap[lowercase(d)] +def precision(d): + typemap = [ + "fp32", + "fp16", + "fp8", + "mxfp8", + "nvfp4" + ] + if lowercase(d) not in typemap: + raise TypeError + return lowercase(d) + + te_layer_map = { "linear": te.Linear, "layernorm": te.LayerNorm, @@ -191,6 +204,12 @@ def parse_fsdp_args(): default=torch.bfloat16, help="Data type for input tensor and Transformer Engine module parameters.", ) + parser.add_argument( + "--precision", + type=precision, + default="fp8", + help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)", + ) return parser.parse_args() @@ -209,6 +228,42 @@ def train(opts): # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) + + # Determining the format and recipe for the training + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + dtype=opts.dtype + + match opts.precision: + case "fp32": + dtype=torch.float32 + no_fp8 = True + case "fp16": + dtype=torch.bfloat16 + no_fp8 = True + case "fp8": + dtype=torch.bfloat16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = False + case "mxfp8": + dtype=torch.bfloat16 + precision_format = Format.E4M3 + recipe = MXFP8BlockScaling(fp8_format=precision_format) + no_fp8 = False + case "nvfp4": + dtype=torch.bfloat16 # RHT only supports bfloat16 + recipe = NVFP4BlockScaling() + no_fp8 = False + case _: + dtype=torch.bfloat16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + + layer_kwargs["params_dtype"]=dtype + if opts.num_layers > 1: te_layer_list = [] for i in range(opts.num_layers): @@ -258,10 +313,6 @@ def train(opts): dist_print(f"Post-FSDP memory use = {post_mem_use}MiB") dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}") - # Fp8 setup for TE - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) @@ -281,11 +332,11 @@ def train(opts): opts.seq_length, opts.batch_size, opts.num_heads * opts.head_dim, - dtype=opts.dtype, + dtype=dtype, device="cuda", ) # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From cd9784307abf0b2b1a1dafd93c9e09905ad0017e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:40:44 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 43 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 2b8dc3a436..882e76ce4b 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -18,7 +18,12 @@ ) import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -69,13 +74,7 @@ def torch_dtype(d): def precision(d): - typemap = [ - "fp32", - "fp16", - "fp8", - "mxfp8", - "nvfp4" - ] + typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"] if lowercase(d) not in typemap: raise TypeError return lowercase(d) @@ -231,38 +230,44 @@ def train(opts): # Determining the format and recipe for the training precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 - dtype=opts.dtype + dtype = opts.dtype match opts.precision: case "fp32": - dtype=torch.float32 + dtype = torch.float32 no_fp8 = True case "fp16": - dtype=torch.bfloat16 + dtype = torch.bfloat16 no_fp8 = True case "fp8": - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False case "mxfp8": - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 + dtype = torch.bfloat16 # RHT only supports bfloat16 recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 - layer_kwargs["params_dtype"]=dtype + layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: te_layer_list = [] From 40919d8949df4cc2b9ae440bcc8433dd4577e93e Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 13:19:40 -0500 Subject: [PATCH 03/18] Fix FP16 dtype mapping and implement CLI flag precedence Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts. - Fix case fp16 to use torch.float16 instead of torch.bfloat16 - Add flag precedence detection by comparing against default values - Implement warning messages when --dtype or --no-fp8 override --precision - Update argument parser help text to document precedence behavior - Ensure --dtype and --no-fp8 take precedence over --precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 2b8dc3a436..f38e04207d 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -220,6 +220,12 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): + + # Check which flags were explicitly set + dtype_explicitly_set = opts.dtype != torch.bfloat16 + no_fp8_explicitly_set = opts.no_fp8 != False + precision_is_non_default = opts.precision != "fp8" + # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") torch.cuda.set_device(LOCAL_RANK) From 5c1db12c3b34e8d836be0285ed33a4f5b449c308 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:31:35 -0500 Subject: [PATCH 04/18] Add logging and documentation for precision configuration Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency. - Add log message showing which precision preset is being used - Add warning logs when --dtype or --no-fp8 override --precision - Add final training configuration log (dtype, FP8 status, recipe) - Enhance argument parser help text with precedence examples - Add inline code comments explaining precedence logic Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 103 +++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index f38e04207d..78a4fa2115 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -186,7 +186,11 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disables the te.autocast() context.", + help="Disables the te.autocast() context. When set, FP8 training is disabled " + + "and the model trains in standard precision (as specified by --dtype). " + + "Takes precedence over --precision if both are specified. " + + "Example: '--no-fp8 --precision fp8' will disable FP8 despite fp8 preset. " + + "Default: False (FP8 enabled based on precision).", ) parser.add_argument( "--no-defer-init", @@ -202,13 +206,23 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, - help="Data type for input tensor and Transformer Engine module parameters.", + help="Data type for input tensor and Transformer Engine module parameters. " + + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + + "Takes precedence over --precision if both are specified. " + + "Example: '--dtype fp16 --precision fp8' will use fp16 dtype and ignore fp8 preset. " + + "Default: bfloat16.", ) parser.add_argument( "--precision", type=precision, default="fp8", - help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)", + help="Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " + + "and a warning will be issued. " + + "Precedence: --dtype and --no-fp8 override --precision. " + + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit control. " + + "Default: fp8.", ) return parser.parse_args() @@ -235,39 +249,60 @@ def train(opts): # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) - # Determining the format and recipe for the training - precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 - dtype=opts.dtype - - match opts.precision: - case "fp32": - dtype=torch.float32 - no_fp8 = True - case "fp16": - dtype=torch.bfloat16 - no_fp8 = True - case "fp8": - dtype=torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = False - case "mxfp8": - dtype=torch.bfloat16 - precision_format = Format.E4M3 - recipe = MXFP8BlockScaling(fp8_format=precision_format) - no_fp8 = False - case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 - recipe = NVFP4BlockScaling() - no_fp8 = False - case _: - dtype=torch.bfloat16 + if not dtype_explicitly_set and not no_fp8_explicitly_set: + + dist_print(f"Using precision preset: {opts.precision}") + + match opts.precision: + case "fp32": + dtype=torch.float32 + no_fp8 = True + case "fp16": + dtype=torch.float16 + no_fp8 = True + case "fp8": + dtype=torch.float16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = False + case "mxfp8": + dtype=torch.float16 + precision_format = Format.E4M3 + recipe = MXFP8BlockScaling(fp8_format=precision_format) + no_fp8 = False + case "nvfp4": + dtype=torch.bfloat16 # RHT only supports bfloat16 + recipe = NVFP4BlockScaling() + no_fp8 = False + case _: + dtype=torch.float16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + else: + # dtype and/or no_fp8 were explicitly set - they take precedence + dtype = opts.dtype + no_fp8 = opts.no_fp8 + + # Set up default recipe for FP8 cases + if not no_fp8: precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 - + else: + recipe = None + + # Warn if precision was also set to non-default (being overridden) + if precision_is_non_default: + if dtype_explicitly_set: + dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") + if no_fp8_explicitly_set: + dist_print(f"Warning: --no-fp8 overrides --precision {opts.precision}") + + # Always log the final configuration being used + dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") + if not no_fp8: + dist_print(f"Using FP8 recipe: {type(recipe).__name__}") + layer_kwargs["params_dtype"]=dtype if opts.num_layers > 1: From e4846c8a48aaea2a1fcf8c41cb8e34aa12597a27 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:40:23 -0500 Subject: [PATCH 05/18] Initialize recipe variable in all precision cases Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True. - Add DelayedScaling recipe setup for fp32 case with no_fp8=True - Add DelayedScaling recipe setup for fp16 case with no_fp8=True - Add inline comments explaining recipe is set up but not used by autocast - Ensure recipe variable is defined in all precision branches for consistency Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 78a4fa2115..b28b3d83f4 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -256,9 +256,19 @@ def train(opts): match opts.precision: case "fp32": dtype=torch.float32 + + #set up, but not used by autocast with no-fp8 set to true + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = True case "fp16": dtype=torch.float16 + + #set up, but not used by autocast with no-fp8 set to true + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = True case "fp8": dtype=torch.float16 From 26aee2f25a85e0534dd778a8021a4e39a9d27c0f Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:48:42 -0500 Subject: [PATCH 06/18] Fix dtype flag detection to support explicit override behavior Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value. - Add sys import for command-line argument detection - Change dtype_explicitly_set check to use '--dtype' in sys.argv - Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv - Ensure --dtype bf16 correctly overrides --precision even when matching default - Maintain warning messages when explicit flags override precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b28b3d83f4..5b5e4b80ea 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -234,10 +234,11 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): + import sys # Check which flags were explicitly set - dtype_explicitly_set = opts.dtype != torch.bfloat16 - no_fp8_explicitly_set = opts.no_fp8 != False + dtype_explicitly_set = '--dtype' in sys.argv + no_fp8_explicitly_set = '--no-fp8' in sys.argv precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -264,7 +265,7 @@ def train(opts): no_fp8 = True case "fp16": dtype=torch.float16 - + #set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") From c9524e38dbf9ce85a6c49a41eda66f753f601569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:03:02 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 53 ++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b8962e48f9..ae4f00c6cd 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -215,12 +215,15 @@ def parse_fsdp_args(): "--precision", type=precision, default="fp8", - help="Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + help=( + "Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + ) + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " + "and a warning will be issued. " + "Precedence: --dtype and --no-fp8 override --precision. " - + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit control. " + + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit" + " control. " + "Default: fp8.", ) return parser.parse_args() @@ -236,8 +239,8 @@ def train(opts): import sys # Check which flags were explicitly set - dtype_explicitly_set = '--dtype' in sys.argv - no_fp8_explicitly_set = '--no-fp8' in sys.argv + dtype_explicitly_set = "--dtype" in sys.argv + no_fp8_explicitly_set = "--no-fp8" in sys.argv precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -255,39 +258,47 @@ def train(opts): match opts.precision: case "fp32": - dtype=torch.float32 + dtype = torch.float32 - #set up, but not used by autocast with no-fp8 set to true + # set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = True case "fp16": - dtype=torch.float16 + dtype = torch.float16 - #set up, but not used by autocast with no-fp8 set to true + # set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = True case "fp8": - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False case "mxfp8": - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 + dtype = torch.bfloat16 # RHT only supports bfloat16 recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 else: # dtype and/or no_fp8 were explicitly set - they take precedence @@ -297,11 +308,13 @@ def train(opts): # Set up default recipe for FP8 cases if not no_fp8: precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) else: recipe = None - # Warn if precision was also set to non-default (being overridden) + # Warn if precision was also set to non-default (being overridden) if precision_is_non_default: if dtype_explicitly_set: dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") @@ -312,8 +325,8 @@ def train(opts): dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") if not no_fp8: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") - - layer_kwargs["params_dtype"]=dtype + + layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: te_layer_list = [] From a506d402ed6e78547d59b096f57f90539840a872 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 16:15:36 -0500 Subject: [PATCH 08/18] Replace sys.argv parsing with custom action and fix default case Replace fragile sys.argv parsing with robust custom argparse action class to track explicitly set arguments, and fix default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior. - Add StoreExplicitAction custom action class for tracking explicit arguments - Update --dtype argument to use StoreExplicitAction - Replace sys.argv check with getattr for dtype_explicitly_set attribute - Remove sys import from train() function - Fix default case to set no_fp8 = False instead of opts.no_fp8 - Ensure recipe variable is properly initialized in all code paths - Support all argument passing methods including config files and = syntax Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b8962e48f9..13185a367e 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -124,6 +124,13 @@ def get_layer_args(opts): return layer_args, layer_kwargs +class StoreExplicitAction(argparse.Action): + """Custom action that tracks whether an argument was explicitly set.""" + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + setattr(namespace, f'{self.dest}_explicitly_set', True) + + def parse_fsdp_args(): parser = argparse.ArgumentParser( description="Run Transformer Engine modules with the " @@ -205,6 +212,7 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, + action=StoreExplicitAction, # Add custom action help="Data type for input tensor and Transformer Engine module parameters. " + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + "Takes precedence over --precision if both are specified. " @@ -233,11 +241,9 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): - import sys - # Check which flags were explicitly set - dtype_explicitly_set = '--dtype' in sys.argv - no_fp8_explicitly_set = '--no-fp8' in sys.argv + dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) + no_fp8_explicitly_set = opts.no_fp8 != False precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -288,7 +294,7 @@ def train(opts): dtype=torch.float16 precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 + no_fp8 = False else: # dtype and/or no_fp8 were explicitly set - they take precedence dtype = opts.dtype From ec31f2a736f6e4f78a92080aa0a491c1abba5acb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:18:47 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 465d918663..45accb47ee 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,9 +126,10 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) def parse_fsdp_args(): @@ -245,7 +246,7 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set - dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) + dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) no_fp8_explicitly_set = opts.no_fp8 != False precision_is_non_default = opts.precision != "fp8" @@ -302,7 +303,9 @@ def train(opts): case _: dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False else: # dtype and/or no_fp8 were explicitly set - they take precedence From 07a87a7891dcb42be62109ae608ad3a08b6c3ac2 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 16:47:43 -0500 Subject: [PATCH 10/18] Fix params_dtype to use computed dtype from precision logic Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model. - Remove params_dtype from get_layer_args() layer_kwargs initialization - Update FSDP MixedPrecision param_dtype to use computed dtype variable - Ensure precision preset logic is respected in both layer initialization and FSDP - Maintain backward compatibility with original FP8-enabled default behavior Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 465d918663..65ac782f54 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -103,7 +103,7 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - "params_dtype": opts.dtype, + #"params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -130,6 +130,15 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) setattr(namespace, f'{self.dest}_explicitly_set', True) +class StoreTrueExplicitAction(argparse.Action): + """Custom action for store_true that tracks whether flag was explicitly set.""" + def __init__(self, option_strings, dest, default=False, required=False, help=None): + super().__init__(option_strings, dest, nargs=0, const=True, + default=default, required=required, help=help) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, True) + setattr(namespace, f'{self.dest}_explicitly_set', True) def parse_fsdp_args(): parser = argparse.ArgumentParser( @@ -190,7 +199,7 @@ def parse_fsdp_args(): ) parser.add_argument( "--no-fp8", - action="store_true", + action=StoreTrueExplicitAction, # Use custom action default=False, help="Disables the te.autocast() context. When set, FP8 training is disabled " + "and the model trains in standard precision (as specified by --dtype). " @@ -246,7 +255,7 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) - no_fp8_explicitly_set = opts.no_fp8 != False + no_fp8_explicitly_set = getattr(opts, 'no_fp8_explicitly_set', False) # Fixed precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -284,14 +293,14 @@ def train(opts): no_fp8 = True case "fp8": - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.HYBRID recipe = DelayedScaling( fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" ) no_fp8 = False case "mxfp8": - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False @@ -300,7 +309,7 @@ def train(opts): recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") no_fp8 = False @@ -327,7 +336,7 @@ def train(opts): # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") - if not no_fp8: + if not no_fp8 and recipe is not None: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") layer_kwargs["params_dtype"] = dtype @@ -362,7 +371,7 @@ def train(opts): process_group=all_gpus, use_orig_params=True, mixed_precision=MixedPrecision( - param_dtype=opts.dtype, + param_dtype=dtype, reduce_dtype=torch.float32, ), auto_wrap_policy=fsdp_wrap_policy, From c6fb3a51f7ea23370f67a657224e53aad6cd0c71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:56:21 +0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 6431712930..9287688616 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -103,7 +103,7 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - #"params_dtype": opts.dtype, + # "params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -131,15 +131,18 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) setattr(namespace, f"{self.dest}_explicitly_set", True) + class StoreTrueExplicitAction(argparse.Action): """Custom action for store_true that tracks whether flag was explicitly set.""" + def __init__(self, option_strings, dest, default=False, required=False, help=None): - super().__init__(option_strings, dest, nargs=0, const=True, - default=default, required=required, help=help) + super().__init__( + option_strings, dest, nargs=0, const=True, default=default, required=required, help=help + ) def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, True) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) def parse_fsdp_args(): @@ -256,8 +259,8 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set - dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) - no_fp8_explicitly_set = getattr(opts, 'no_fp8_explicitly_set', False) # Fixed + dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) + no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Fixed precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group From 2c5473e92e6c0df93bf349426a43d7ea747c39c5 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 17:09:52 -0500 Subject: [PATCH 12/18] Fix type conversion in StoreExplicitAction for --dtype argument Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations. - Store type converter in StoreExplicitAction.__init__ - Apply type conversion in __call__ before setting attribute value - Add error handling for invalid type conversions - Ensure opts.dtype contains torch dtype object, not raw string - Fix runtime errors in torch.rand() and MixedPrecision() calls Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 9287688616..0fe2e198cb 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,10 +126,22 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __init__(self, option_strings, dest, type=None, **kwargs): + super().__init__(option_strings, dest, **kwargs) + self.type_converter = type # Store the type converter def __call__(self, parser, namespace, values, option_string=None): + # Apply the type converter if one was provided + if self.type_converter is not None: + try: + values = self.type_converter(values) + except (ValueError, TypeError) as e: + raise argparse.ArgumentTypeError( + f"invalid {self.dest} value: {values}" + ) from e + setattr(namespace, self.dest, values) - setattr(namespace, f"{self.dest}_explicitly_set", True) + setattr(namespace, f'{self.dest}_explicitly_set', True) class StoreTrueExplicitAction(argparse.Action): From a9e664c4d4d7d1cc57ff5e59a4c0838726e4f6cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:11:04 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 0fe2e198cb..c1a500815a 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,6 +126,7 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __init__(self, option_strings, dest, type=None, **kwargs): super().__init__(option_strings, dest, **kwargs) self.type_converter = type # Store the type converter @@ -136,12 +137,10 @@ def __call__(self, parser, namespace, values, option_string=None): try: values = self.type_converter(values) except (ValueError, TypeError) as e: - raise argparse.ArgumentTypeError( - f"invalid {self.dest} value: {values}" - ) from e + raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e setattr(namespace, self.dest, values) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) class StoreTrueExplicitAction(argparse.Action): From 196df3df94c07e230e152a0d1f4db37fa148e987 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 22:45:46 -0500 Subject: [PATCH 14/18] Fix precision preset recipe selection and add incompatibility validation Address critical bugs where FP8 recipes were incorrectly selected when explicit flags were set, and add validation to prevent incompatible flag combinations that would silently disable FP8 training. - Remove default value from --precision parameter (set to None for backward compatibility) - Add get_precision_preset() and get_recipe_for_precision() helper functions - Implement two-path configuration logic: backward compatibility mode vs. precision preset mode - Add incompatibility validation: raise ValueError when --no-fp8 used with fp8/mxfp8/nvfp4 presets - Preserve FP8 recipe selection when --dtype explicitly overrides precision preset dtype - Fix fp16 case to correctly map to torch.float16 instead of torch.bfloat16 - Update parameter help text with precedence rules and usage examples - Ensure backward compatibility: scripts without --precision work identically to original version Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 228 ++++++++++++++++++++-------------- 1 file changed, 138 insertions(+), 90 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index c1a500815a..fa68b72840 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -215,13 +215,24 @@ def parse_fsdp_args(): ) parser.add_argument( "--no-fp8", - action=StoreTrueExplicitAction, # Use custom action + action=StoreTrueExplicitAction, default=False, - help="Disables the te.autocast() context. When set, FP8 training is disabled " - + "and the model trains in standard precision (as specified by --dtype). " - + "Takes precedence over --precision if both are specified. " - + "Example: '--no-fp8 --precision fp8' will disable FP8 despite fp8 preset. " - + "Default: False (FP8 enabled based on precision).", + help=( + "Disables the te.autocast() context. When set, FP8 training is disabled " + "and the model trains in standard precision (as specified by --dtype). " + "PRECEDENCE: This flag is incompatible with FP8-based --precision presets. " + "BEHAVIOR: " + "- Without --precision: Disables FP8 training (original behavior) " + "- With --precision fp32/fp16: Redundant but harmless (already non-FP8) " + "- With --precision fp8/mxfp8/nvfp4: RAISES ERROR (incompatible flags) " + "RATIONALE: FP8 presets explicitly request FP8 training, so disabling FP8 " + "would contradict the user's intent. Use --precision fp32/fp16 instead for non-FP8 training. " + "EXAMPLES: " + "'--no-fp8' disables FP8 (original behavior). " + "'--precision fp8 --no-fp8' raises ValueError (incompatible). " + "'--precision fp16' is the correct way to request non-FP8 training. " + "Default: False (FP8 enabled based on configuration)." + ), ) parser.add_argument( "--no-defer-init", @@ -237,27 +248,41 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, - action=StoreExplicitAction, # Add custom action - help="Data type for input tensor and Transformer Engine module parameters. " - + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " - + "Takes precedence over --precision if both are specified. " - + "Example: '--dtype fp16 --precision fp8' will use fp16 dtype and ignore fp8 preset. " - + "Default: bfloat16.", + action=StoreExplicitAction, + help=( + "Data type for input tensor and Transformer Engine module parameters. " + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + "PRECEDENCE: When explicitly set, this flag overrides the dtype from --precision preset. " + "BEHAVIOR: " + "- Without --precision: Controls parameter dtype directly " + "- With --precision: Overrides preset's dtype but preserves FP8 recipe selection " + "EXAMPLES: " + "'--dtype bf16' uses bfloat16 parameters (original behavior). " + "'--precision mxfp8 --dtype fp16' uses fp16 parameters with MXFP8BlockScaling recipe. " + "A warning is issued when overriding --precision dtype. " + "Default: bfloat16." + ), ) parser.add_argument( "--precision", type=precision, - default="fp8", + default=None, help=( - "Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " - ) - + "This is a convenience flag that configures both dtype and FP8 settings automatically. " - + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " - + "and a warning will be issued. " - + "Precedence: --dtype and --no-fp8 override --precision. " - + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit" - " control. " - + "Default: fp8.", + "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4. " + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + "- fp32/fp16: Non-FP8 training with specified dtype " + "- fp8: FP8 training with DelayedScaling recipe (bf16 parameters) " + "- mxfp8: FP8 training with MXFP8BlockScaling recipe (bf16 parameters) " + "- nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16 parameters) " + "PRECEDENCE RULES: " + "- If --dtype is explicitly set, it overrides the dtype from this preset (with warning) " + "- If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an error is raised (incompatible) " + "- If this flag is not set, original behavior is used (--dtype and --no-fp8 control training) " + "EXAMPLES: " + "'--precision mxfp8' enables MXFP8 FP8 training with bf16 parameters. " + "'--precision fp8 --dtype fp16' uses fp16 parameters but keeps DelayedScaling recipe. " + "Default: None (backward compatible mode)." + ), ) return parser.parse_args() @@ -267,12 +292,60 @@ def dist_print(text, all_ranks=False, no_new_line=False): end = "" if no_new_line else "\n" print(f"[GPU-{LOCAL_RANK}] " + text, end=end) +def get_precision_preset(precision_value): + """Get dtype, no_fp8, and recipe based on precision preset. + + Returns: + tuple: (dtype, no_fp8, recipe) + """ + match precision_value: + case "fp32": + return torch.float32, True, None + case "fp16": + return torch.float16, True, None + case "fp8": + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + return torch.bfloat16, False, recipe + case "mxfp8": + recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + return torch.bfloat16, False, recipe + case "nvfp4": + recipe = NVFP4BlockScaling() + return torch.bfloat16, False, recipe + case _: + # Default to fp8 behavior + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + return torch.bfloat16, False, recipe + + +def get_recipe_for_precision(precision_value): + """Get FP8 recipe based on precision preset (when FP8 is enabled). + + Args: + precision_value: The precision preset string + + Returns: + Recipe object for FP8 training + """ + match precision_value: + case "mxfp8": + return MXFP8BlockScaling(fp8_format=Format.E4M3) + case "nvfp4": + return NVFP4BlockScaling() + case _: + # Default to DelayedScaling for fp8 or any other value + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) - no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Fixed - precision_is_non_default = opts.precision != "fp8" + no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") @@ -280,83 +353,58 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM - layer_args, layer_kwargs = get_layer_args(opts) - - if not dtype_explicitly_set and not no_fp8_explicitly_set: - - dist_print(f"Using precision preset: {opts.precision}") - - match opts.precision: - case "fp32": - dtype = torch.float32 - - # set up, but not used by autocast with no-fp8 set to true - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - - no_fp8 = True - case "fp16": - dtype = torch.float16 - - # set up, but not used by autocast with no-fp8 set to true - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - - no_fp8 = True - case "fp8": - dtype = torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - no_fp8 = False - case "mxfp8": - dtype = torch.bfloat16 - precision_format = Format.E4M3 - recipe = MXFP8BlockScaling(fp8_format=precision_format) - no_fp8 = False - case "nvfp4": - dtype = torch.bfloat16 # RHT only supports bfloat16 - recipe = NVFP4BlockScaling() - no_fp8 = False - case _: - dtype = torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - no_fp8 = False - else: - # dtype and/or no_fp8 were explicitly set - they take precedence + # Determine final configuration based on precedence rules + if opts.precision is None: + # Case 1: Backward compatibility - no precision preset specified + # Use original behavior with dtype and no_fp8 flags dtype = opts.dtype no_fp8 = opts.no_fp8 - - # Set up default recipe for FP8 cases + + # Set up recipe if FP8 is enabled if not no_fp8: - precision_format = Format.HYBRID recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - recipe = None - - # Warn if precision was also set to non-default (being overridden) - if precision_is_non_default: - if dtype_explicitly_set: - dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") - if no_fp8_explicitly_set: - dist_print(f"Warning: --no-fp8 overrides --precision {opts.precision}") + recipe = None + else: + # Case 2: Precision preset was explicitly specified + # Start with precision preset values + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + + # Check for incompatible flag combinations + # Error if user requests FP8-based precision but also sets --no-fp8 + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8: + raise ValueError( + f"Cannot use --no-fp8 with --precision {opts.precision}. " + f"These flags are incompatible. " + f"Either remove --no-fp8 to use {opts.precision} training, " + f"or use --precision fp32/fp16 for non-FP8 training." + ) + + dtype = preset_dtype + no_fp8 = preset_no_fp8 + recipe = preset_recipe + + dist_print(f"Using precision preset: {opts.precision}") + + # Apply explicit dtype override with warning + if dtype_explicitly_set: + dtype = opts.dtype + dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting") + + # If FP8 is still enabled, keep recipe based on precision + # (dtype only affects parameter storage, not FP8 recipe) + if not no_fp8: + recipe = get_recipe_for_precision(opts.precision) # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") if not no_fp8 and recipe is not None: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM + layer_args, layer_kwargs = get_layer_args(opts) layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: @@ -457,4 +505,4 @@ def train(opts): # torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) test_fsdp.py --defer-init if __name__ == "__main__": args = parse_fsdp_args() - train(args) + train(args) \ No newline at end of file From 368820ba52063898853ffb47e91705b6d8777c4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 03:47:34 +0000 Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 101 ++++++++++++++++------------------ 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index fa68b72840..cfa2c135ab 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -218,20 +218,17 @@ def parse_fsdp_args(): action=StoreTrueExplicitAction, default=False, help=( - "Disables the te.autocast() context. When set, FP8 training is disabled " - "and the model trains in standard precision (as specified by --dtype). " - "PRECEDENCE: This flag is incompatible with FP8-based --precision presets. " - "BEHAVIOR: " - "- Without --precision: Disables FP8 training (original behavior) " - "- With --precision fp32/fp16: Redundant but harmless (already non-FP8) " - "- With --precision fp8/mxfp8/nvfp4: RAISES ERROR (incompatible flags) " - "RATIONALE: FP8 presets explicitly request FP8 training, so disabling FP8 " - "would contradict the user's intent. Use --precision fp32/fp16 instead for non-FP8 training. " - "EXAMPLES: " - "'--no-fp8' disables FP8 (original behavior). " - "'--precision fp8 --no-fp8' raises ValueError (incompatible). " - "'--precision fp16' is the correct way to request non-FP8 training. " - "Default: False (FP8 enabled based on configuration)." + "Disables the te.autocast() context. When set, FP8 training is disabled and the model" + " trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is" + " incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:" + " Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but" + " harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR" + " (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so" + " disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead" + " for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)." + " '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is" + " the correct way to request non-FP8 training. Default: False (FP8 enabled based on" + " configuration)." ), ) parser.add_argument( @@ -250,17 +247,14 @@ def parse_fsdp_args(): default=torch.bfloat16, action=StoreExplicitAction, help=( - "Data type for input tensor and Transformer Engine module parameters. " - "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " - "PRECEDENCE: When explicitly set, this flag overrides the dtype from --precision preset. " - "BEHAVIOR: " - "- Without --precision: Controls parameter dtype directly " - "- With --precision: Overrides preset's dtype but preserves FP8 recipe selection " - "EXAMPLES: " - "'--dtype bf16' uses bfloat16 parameters (original behavior). " - "'--precision mxfp8 --dtype fp16' uses fp16 parameters with MXFP8BlockScaling recipe. " - "A warning is issued when overriding --precision dtype. " - "Default: bfloat16." + "Data type for input tensor and Transformer Engine module parameters. Supported values:" + " fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag" + " overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:" + " Controls parameter dtype directly - With --precision: Overrides preset's dtype but" + " preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters" + " (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with" + " MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype." + " Default: bfloat16." ), ) parser.add_argument( @@ -268,20 +262,17 @@ def parse_fsdp_args(): type=precision, default=None, help=( - "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4. " - "This is a convenience flag that configures both dtype and FP8 settings automatically. " - "- fp32/fp16: Non-FP8 training with specified dtype " - "- fp8: FP8 training with DelayedScaling recipe (bf16 parameters) " - "- mxfp8: FP8 training with MXFP8BlockScaling recipe (bf16 parameters) " - "- nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16 parameters) " - "PRECEDENCE RULES: " - "- If --dtype is explicitly set, it overrides the dtype from this preset (with warning) " - "- If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an error is raised (incompatible) " - "- If this flag is not set, original behavior is used (--dtype and --no-fp8 control training) " - "EXAMPLES: " - "'--precision mxfp8' enables MXFP8 FP8 training with bf16 parameters. " - "'--precision fp8 --dtype fp16' uses fp16 parameters but keeps DelayedScaling recipe. " - "Default: None (backward compatible mode)." + "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4." + " This is a convenience flag that configures both dtype and FP8 settings automatically." + " - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with" + " DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling" + " recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16" + " parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype" + " from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an" + " error is raised (incompatible) - If this flag is not set, original behavior is used" + " (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8" + " FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16" + " parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)." ), ) return parser.parse_args() @@ -292,9 +283,10 @@ def dist_print(text, all_ranks=False, no_new_line=False): end = "" if no_new_line else "\n" print(f"[GPU-{LOCAL_RANK}] " + text, end=end) + def get_precision_preset(precision_value): """Get dtype, no_fp8, and recipe based on precision preset. - + Returns: tuple: (dtype, no_fp8, recipe) """ @@ -324,10 +316,10 @@ def get_precision_preset(precision_value): def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled). - + Args: precision_value: The precision preset string - + Returns: Recipe object for FP8 training """ @@ -342,6 +334,7 @@ def get_recipe_for_precision(precision_value): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) + def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) @@ -359,40 +352,42 @@ def train(opts): # Use original behavior with dtype and no_fp8 flags dtype = opts.dtype no_fp8 = opts.no_fp8 - + # Set up recipe if FP8 is enabled if not no_fp8: recipe = DelayedScaling( fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - recipe = None + recipe = None else: # Case 2: Precision preset was explicitly specified # Start with precision preset values preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) - + # Check for incompatible flag combinations # Error if user requests FP8-based precision but also sets --no-fp8 if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8: raise ValueError( f"Cannot use --no-fp8 with --precision {opts.precision}. " - f"These flags are incompatible. " + "These flags are incompatible. " f"Either remove --no-fp8 to use {opts.precision} training, " - f"or use --precision fp32/fp16 for non-FP8 training." + "or use --precision fp32/fp16 for non-FP8 training." ) - + dtype = preset_dtype no_fp8 = preset_no_fp8 recipe = preset_recipe - + dist_print(f"Using precision preset: {opts.precision}") - + # Apply explicit dtype override with warning if dtype_explicitly_set: dtype = opts.dtype - dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting") - + dist_print( + f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" + ) + # If FP8 is still enabled, keep recipe based on precision # (dtype only affects parameter storage, not FP8 recipe) if not no_fp8: @@ -505,4 +500,4 @@ def train(opts): # torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) test_fsdp.py --defer-init if __name__ == "__main__": args = parse_fsdp_args() - train(args) \ No newline at end of file + train(args) From 76dcb94bbdafd02bdacb9b46ae21b15107cd182f Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 22:55:43 -0500 Subject: [PATCH 16/18] Fix unreachable default case and redundant recipe recreation Remove dead code in get_precision_preset() default case and eliminate redundant recipe recreation when dtype is explicitly overridden, ensuring cleaner logic flow and preventing duplicate recipe instantiation. - Remove unreachable case _: branch from get_precision_preset() function - Delete redundant recipe recreation when dtype_explicitly_set is true - Preserve existing recipe from preset when dtype override occurs - Ensure dtype override only affects parameter storage, not FP8 recipe selection Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index cfa2c135ab..7ebc34afc1 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -306,12 +306,6 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe - case _: - # Default to fp8 behavior - recipe = DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) - return torch.bfloat16, False, recipe def get_recipe_for_precision(precision_value): @@ -395,8 +389,6 @@ def train(opts): # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") - if not no_fp8 and recipe is not None: - dist_print(f"Using FP8 recipe: {type(recipe).__name__}") # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) From e22c2f2702c62b5cce7a2f10c5bb95f4fca6878a Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 23:04:43 -0500 Subject: [PATCH 17/18] Add explicit error handling for invalid precision presets Prevent silent failures when precision validation is bypassed or new presets are added without updating get_precision_preset() function by adding explicit ValueError for unhandled cases. - Add case _: branch to get_precision_preset() that raises ValueError - Ensure invalid precision values fail loudly with clear error message - Prevent TypeError on tuple unpacking if function returns None - Improve maintainability when adding new precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 7ebc34afc1..61c376e7aa 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -306,7 +306,12 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe - + case _: + # Fail loudly if validation is bypassed or new preset added without updating this function + raise ValueError( + f"Invalid precision preset: {precision_value}. " + f"Supported values: fp32, fp16, fp8, mxfp8, nvfp4" + ) def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled). From 9d637d56266d72db851ad80039ffd14ae4459f03 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 04:09:02 +0000 Subject: [PATCH 18/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 61c376e7aa..0da5b265d2 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -310,9 +310,10 @@ def get_precision_preset(precision_value): # Fail loudly if validation is bypassed or new preset added without updating this function raise ValueError( f"Invalid precision preset: {precision_value}. " - f"Supported values: fp32, fp16, fp8, mxfp8, nvfp4" + "Supported values: fp32, fp16, fp8, mxfp8, nvfp4" ) + def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled).