Add multi-precision training support to FSDP script #2662
base: main
@@ -18,7 +18,12 @@
 )
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling
+from transformer_engine.common.recipe import (
+    Format,
+    DelayedScaling,
+    MXFP8BlockScaling,
+    NVFP4BlockScaling,
+)
 from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp


 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

@@ -68,6 +73,13 @@ def torch_dtype(d):
     return typemap[lowercase(d)]


+def precision(d):
+    typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
+    if lowercase(d) not in typemap:
+        raise TypeError
+    return lowercase(d)
+
+
 te_layer_map = {
     "linear": te.Linear,
     "layernorm": te.LayerNorm,

@@ -91,7 +103,7 @@ def get_layer_args(opts):
     hidden_size = opts.num_heads * opts.head_dim
     layer_args = (hidden_size,)
     layer_kwargs = {
-        "params_dtype": opts.dtype,
+        # "params_dtype": opts.dtype,
         "device": "cuda" if opts.no_defer_init else "meta",
         "get_rng_state_tracker": get_cuda_rng_tracker,
     }

@@ -112,6 +124,38 @@ def get_layer_args(opts):
     return layer_args, layer_kwargs


+class StoreExplicitAction(argparse.Action):
+    """Custom action that tracks whether an argument was explicitly set."""
+
+    def __init__(self, option_strings, dest, type=None, **kwargs):
+        super().__init__(option_strings, dest, **kwargs)
+        self.type_converter = type  # Store the type converter
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Apply the type converter if one was provided
+        if self.type_converter is not None:
+            try:
+                values = self.type_converter(values)
+            except (ValueError, TypeError) as e:
+                raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e
+
+        setattr(namespace, self.dest, values)
+        setattr(namespace, f"{self.dest}_explicitly_set", True)
+
+
+class StoreTrueExplicitAction(argparse.Action):
+    """Custom action for store_true that tracks whether flag was explicitly set."""
+
+    def __init__(self, option_strings, dest, default=False, required=False, help=None):
+        super().__init__(
+            option_strings, dest, nargs=0, const=True, default=default, required=required, help=help
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, True)
+        setattr(namespace, f"{self.dest}_explicitly_set", True)
+
+
 def parse_fsdp_args():
     parser = argparse.ArgumentParser(
         description="Run Transformer Engine modules with the "

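The two custom actions above exist so later code can tell "the user passed this flag" apart from "argparse filled in the default". A minimal, self-contained sketch of the same pattern, assuming nothing beyond the standard library (the class name TrackedStoreTrue and the toy parser are illustrative, not code from this PR):

import argparse

class TrackedStoreTrue(argparse.Action):
    """store_true that also records that the flag was given explicitly."""

    def __init__(self, option_strings, dest, default=False, required=False, help=None):
        super().__init__(
            option_strings, dest, nargs=0, const=True, default=default, required=required, help=help
        )

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, True)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--no-fp8", action=TrackedStoreTrue, default=False)

args = parser.parse_args([])                          # flag omitted
print(args.no_fp8)                                    # False
print(getattr(args, "no_fp8_explicitly_set", False))  # False: default was used

args = parser.parse_args(["--no-fp8"])                # flag given
print(args.no_fp8)                                    # True
print(getattr(args, "no_fp8_explicitly_set", False))  # True: user set it explicitly

The getattr(..., False) fallback is what lets the training code treat "never touched" and "explicitly disabled" as different cases.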
@@ -171,9 +215,21 @@ def parse_fsdp_args():
     )
     parser.add_argument(
         "--no-fp8",
-        action="store_true",
+        action=StoreTrueExplicitAction,
         default=False,
-        help="Disables the te.autocast() context.",
+        help=(
+            "Disables the te.autocast() context. When set, FP8 training is disabled and the model"
+            " trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is"
+            " incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:"
+            " Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but"
+            " harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR"
+            " (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so"
+            " disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead"
+            " for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)."
+            " '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is"
+            " the correct way to request non-FP8 training. Default: False (FP8 enabled based on"
+            " configuration)."
+        ),
     )
     parser.add_argument(
         "--no-defer-init",

@@ -189,7 +245,35 @@ def parse_fsdp_args():
         "--dtype",
         type=torch_dtype,
         default=torch.bfloat16,
-        help="Data type for input tensor and Transformer Engine module parameters.",
+        action=StoreExplicitAction,
+        help=(
+            "Data type for input tensor and Transformer Engine module parameters. Supported values:"
+            " fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag"
+            " overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:"
+            " Controls parameter dtype directly - With --precision: Overrides preset's dtype but"
+            " preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters"
+            " (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with"
+            " MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype."
+            " Default: bfloat16."
+        ),
+    )
+    parser.add_argument(
+        "--precision",
+        type=precision,
+        default=None,
+        help=(
+            "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4."
+            " This is a convenience flag that configures both dtype and FP8 settings automatically."
+            " - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with"
+            " DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling"
+            " recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16"
+            " parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype"
+            " from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an"
+            " error is raised (incompatible) - If this flag is not set, original behavior is used"
+            " (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8"
+            " FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16"
+            " parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)."
+        ),
     )
     return parser.parse_args()

Contributor comment on lines 260 to 278: Conflicting CLI flags
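The precedence described in the --dtype, --no-fp8, and --precision help strings (and flagged in the comment above) reduces to a small decision table. A hedged, standalone sketch of that logic; the function name resolve_flags and its return values are illustrative only, not part of the script:

def resolve_flags(precision, no_fp8, no_fp8_explicit, dtype_explicit):
    """Return (use_fp8, note) for one combination of CLI flags."""
    fp8_presets = {"fp8", "mxfp8", "nvfp4"}
    if precision in fp8_presets and no_fp8_explicit and no_fp8:
        # The preset asks for FP8 while --no-fp8 forbids it: refuse to guess.
        raise ValueError(f"--no-fp8 conflicts with --precision {precision}")
    if precision is None:
        # Backward-compatible path: --dtype and --no-fp8 alone control training.
        return not no_fp8, "legacy behavior"
    use_fp8 = precision in fp8_presets
    note = "explicit --dtype overrides preset dtype" if dtype_explicit else "preset controls dtype"
    return use_fp8, note

print(resolve_flags("mxfp8", False, False, False))  # (True, 'preset controls dtype')
print(resolve_flags(None, True, True, False))       # (False, 'legacy behavior')
# resolve_flags("fp8", True, True, False) raises ValueError, matching the help text.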
@@ -200,15 +284,122 @@ def dist_print(text, all_ranks=False, no_new_line=False):
     print(f"[GPU-{LOCAL_RANK}] " + text, end=end)


+def get_precision_preset(precision_value):
+    """Get dtype, no_fp8, and recipe based on precision preset.
+
+    Returns:
+        tuple: (dtype, no_fp8, recipe)
+    """
+    match precision_value:
+        case "fp32":
+            return torch.float32, True, None
+        case "fp16":
+            return torch.float16, True, None
+        case "fp8":
+            recipe = DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+            return torch.bfloat16, False, recipe
+        case "mxfp8":
+            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
+            return torch.bfloat16, False, recipe
+        case "nvfp4":
+            recipe = NVFP4BlockScaling()
+            return torch.bfloat16, False, recipe
+        case _:
+            # Fail loudly if validation is bypassed or new preset added without updating this function
+            raise ValueError(
+                f"Invalid precision preset: {precision_value}. "
+                "Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
+            )
+
+
+def get_recipe_for_precision(precision_value):
+    """Get FP8 recipe based on precision preset (when FP8 is enabled).
+
+    Args:
+        precision_value: The precision preset string
+
+    Returns:
+        Recipe object for FP8 training
+    """
+    match precision_value:
+        case "mxfp8":
+            return MXFP8BlockScaling(fp8_format=Format.E4M3)
+        case "nvfp4":
+            return NVFP4BlockScaling()
+        case _:
+            # Default to DelayedScaling for fp8 or any other value
+            return DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+
+
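For orientation, a sketch of how these two helpers are meant to be consumed. It assumes the definitions above and the script's torch/recipe imports are in scope; the calls themselves are illustrative, not lines from the diff:

# Presets bundle (parameter dtype, whether FP8 is disabled, FP8 recipe or None).
dtype, no_fp8, recipe = get_precision_preset("fp16")
# -> torch.float16, True, None   (plain half precision, no autocast recipe)

dtype, no_fp8, recipe = get_precision_preset("mxfp8")
# -> torch.bfloat16, False, MXFP8BlockScaling(fp8_format=Format.E4M3)

# When an explicit --dtype overrides a preset, only the recipe choice is re-derived:
recipe = get_recipe_for_precision("nvfp4")  # NVFP4BlockScaling()
recipe = get_recipe_for_precision("fp8")    # DelayedScaling(fp8_format=Format.HYBRID, ...)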
 def train(opts):
+    # Check which flags were explicitly set
+    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
+    no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False)
+
     # Initialize torch.distributed global process group
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(LOCAL_RANK)
     dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
     torch.manual_seed(opts.seed)

+    # Determine final configuration based on precedence rules
+    if opts.precision is None:
+        # Case 1: Backward compatibility - no precision preset specified
+        # Use original behavior with dtype and no_fp8 flags
+        dtype = opts.dtype
+        no_fp8 = opts.no_fp8
+
+        # Set up recipe if FP8 is enabled
+        if not no_fp8:
+            recipe = DelayedScaling(
+                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+            )
+        else:
+            recipe = None
+    else:
+        # Case 2: Precision preset was explicitly specified
+        # Start with precision preset values
+        preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
+
+        # Check for incompatible flag combinations
+        # Error if user requests FP8-based precision but also sets --no-fp8
+        if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8:
+            raise ValueError(
+                f"Cannot use --no-fp8 with --precision {opts.precision}. "
+                "These flags are incompatible. "
+                f"Either remove --no-fp8 to use {opts.precision} training, "
+                "or use --precision fp32/fp16 for non-FP8 training."
+            )
+
+        dtype = preset_dtype
+        no_fp8 = preset_no_fp8
+        recipe = preset_recipe
+
+        dist_print(f"Using precision preset: {opts.precision}")
+
+        # Apply explicit dtype override with warning
+        if dtype_explicitly_set:
+            dtype = opts.dtype
+            dist_print(
+                f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
+            )
+
+            # If FP8 is still enabled, keep recipe based on precision
+            # (dtype only affects parameter storage, not FP8 recipe)
+            if not no_fp8:
+                recipe = get_recipe_for_precision(opts.precision)
+
Contributor comment on lines +386 to +394: Redundant recipe recreation
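One possible way to act on that observation, offered only as a sketch (not part of this diff): the preset already produced a recipe, so the override branch could keep it instead of rebuilding an equivalent object.

# Sketch of a possible simplification, reusing the names from the diff above:
# preset_recipe came from get_precision_preset(opts.precision), and an explicit
# --dtype override only changes parameter storage, so the recipe can be reused.
if dtype_explicitly_set:
    dtype = opts.dtype
    dist_print(
        f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
    )
    # recipe already holds preset_recipe; no call to get_recipe_for_precision() is needed.

Functionally this matches the current code, since get_recipe_for_precision() returns the same recipe the preset produced.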
+    # Always log the final configuration being used
+    dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
+
     # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
     layer_args, layer_kwargs = get_layer_args(opts)
+    layer_kwargs["params_dtype"] = dtype
+
     if opts.num_layers > 1:
         te_layer_list = []
         for i in range(opts.num_layers):

@@ -239,7 +430,7 @@ def train(opts):
         process_group=all_gpus,
         use_orig_params=True,
         mixed_precision=MixedPrecision(
-            param_dtype=opts.dtype,
+            param_dtype=dtype,
             reduce_dtype=torch.float32,
         ),
         auto_wrap_policy=fsdp_wrap_policy,

@@ -258,10 +449,6 @@
     dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
     dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")

-    # Fp8 setup for TE
-    fp8_format = Format.HYBRID
-    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
-
     # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
     optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

@@ -281,11 +468,11 @@
             opts.seq_length,
             opts.batch_size,
             opts.num_heads * opts.head_dim,
-            dtype=opts.dtype,
+            dtype=dtype,
             device="cuda",
         )
         # autocast needs to be given the FSDP process group for amax reductions
-        with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
+        with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
         # calculate gradient and take training step outside the autocast context

Review comment: Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250