From ebeba6f7f7bea2aaf304645803bb5c0808a0ac31 Mon Sep 17 00:00:00 2001
From: Savitha Srinivasan
Date: Fri, 1 May 2026 00:22:51 -0700
Subject: [PATCH 1/2] Add MXFP8 recipe to fully_shard example (reproduces
 checkpoint crash)

---
 examples/pytorch/quantized_model_init/fully_shard.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py
index 2b5ca84ebc..87598552dd 100644
--- a/examples/pytorch/quantized_model_init/fully_shard.py
+++ b/examples/pytorch/quantized_model_init/fully_shard.py
@@ -43,7 +43,7 @@
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
-
+from transformer_engine.common.recipe import MXFP8BlockScaling
 # ── Configuration ────────────────────────────────────────────────────
 HIDDEN_SIZE = 256
 FFN_HIDDEN_SIZE = 1024
@@ -85,7 +85,7 @@ def main():
     # avoiding the precision loss of dequantizing from FP8.
     # We set DTYPE to float32 since these weights will actually be initialized as FP8,
     # but we want to seed the optimizer states (which will be in FP32) with the FP32 values.
-    with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
+    with te.quantized_model_init(recipe=MXFP8BlockScaling(), enabled=True, preserve_high_precision_init_val=True):
         model = torch.nn.Sequential(
             *[
                 te.TransformerLayer(
@@ -154,7 +154,7 @@ def main():
     for step in range(NUM_STEPS):
         optimizer.zero_grad(set_to_none=True)
 
-        with te.autocast(enabled=True):
+        with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
             output = model(x)
             loss = F.mse_loss(output, target)
 
@@ -187,7 +187,7 @@ def main():
     # Verify training continues after checkpoint load.
     optimizer.zero_grad(set_to_none=True)
 
-    with te.autocast(enabled=True):
+    with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
         output = model(x)
         loss = F.mse_loss(output, target)
         loss.backward()

From 3feb257f35f9973494853330a65a1926ca59f81d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 1 May 2026 08:14:56 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/pytorch/quantized_model_init/fully_shard.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py
index 87598552dd..1a8b716b22 100644
--- a/examples/pytorch/quantized_model_init/fully_shard.py
+++ b/examples/pytorch/quantized_model_init/fully_shard.py
@@ -43,7 +43,8 @@
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
-from transformer_engine.common.recipe import MXFP8BlockScaling
+from transformer_engine.common.recipe import MXFP8BlockScaling
+
 # ── Configuration ────────────────────────────────────────────────────
 HIDDEN_SIZE = 256
 FFN_HIDDEN_SIZE = 1024
@@ -85,7 +86,9 @@ def main():
     # avoiding the precision loss of dequantizing from FP8.
     # We set DTYPE to float32 since these weights will actually be initialized as FP8,
     # but we want to seed the optimizer states (which will be in FP32) with the FP32 values.
-    with te.quantized_model_init(recipe=MXFP8BlockScaling(), enabled=True, preserve_high_precision_init_val=True):
+    with te.quantized_model_init(
+        recipe=MXFP8BlockScaling(), enabled=True, preserve_high_precision_init_val=True
+    ):
         model = torch.nn.Sequential(
             *[
                 te.TransformerLayer(