diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py
index 2b5ca84ebc..1a8b716b22 100644
--- a/examples/pytorch/quantized_model_init/fully_shard.py
+++ b/examples/pytorch/quantized_model_init/fully_shard.py
@@ -43,6 +43,7 @@
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
+from transformer_engine.common.recipe import MXFP8BlockScaling
 
 # ── Configuration ────────────────────────────────────────────────────
 HIDDEN_SIZE = 256
@@ -85,7 +86,9 @@ def main():
     # avoiding the precision loss of dequantizing from FP8.
     # We set DTYPE to float32 since these weights will actually be initialized as FP8,
     # but we want to seed the optimizer states (which will be in FP32) with the FP32 values.
-    with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
+    with te.quantized_model_init(
+        recipe=MXFP8BlockScaling(), enabled=True, preserve_high_precision_init_val=True
+    ):
         model = torch.nn.Sequential(
             *[
                 te.TransformerLayer(
@@ -154,7 +157,7 @@ def main():
     for step in range(NUM_STEPS):
         optimizer.zero_grad(set_to_none=True)
 
-        with te.autocast(enabled=True):
+        with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
             output = model(x)
             loss = F.mse_loss(output, target)
 
@@ -187,7 +190,7 @@ def main():
     # Verify training continues after checkpoint load.
     optimizer.zero_grad(set_to_none=True)
 
-    with te.autocast(enabled=True):
+    with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
         output = model(x)
         loss = F.mse_loss(output, target)
         loss.backward()
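
For reference, below is a minimal standalone sketch of the pattern this diff adopts: passing an explicit MXFP8BlockScaling recipe to both the model-init context and the forward-pass autocast. The `te.quantized_model_init`, `te.autocast`, and `MXFP8BlockScaling` calls mirror what the example file itself uses; the single `te.Linear` layer, shapes, learning rate, and device assumptions are illustrative and not taken from the PR. It assumes a CUDA device with MXFP8 support and the transformer_engine package installed.

```python
# Minimal sketch (not part of the PR): initialize a TE layer under an MXFP8
# recipe and run one training step under the same recipe.
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

recipe = MXFP8BlockScaling()

# Weights are created directly as quantized tensors; the FP32 init values are
# preserved so optimizer state can later be seeded at full precision.
with te.quantized_model_init(
    recipe=recipe, enabled=True, preserve_high_precision_init_val=True
):
    layer = te.Linear(256, 256, params_dtype=torch.float32)  # illustrative layer/size

optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-4)
x = torch.randn(32, 256, device="cuda")
target = torch.randn(32, 256, device="cuda")

optimizer.zero_grad(set_to_none=True)
# The forward pass runs under the same recipe used at initialization time.
with te.autocast(enabled=True, recipe=recipe):
    out = layer(x)
    loss = F.mse_loss(out, target)
loss.backward()
optimizer.step()
```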