Skip to content
4 changes: 2 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_flex_attention.xml $TE_PATH/tests/pytorch/attention/test_flex_attention.py || test_fail "test_flex_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_linear_mxfp8_attention.xml $TE_PATH/tests/pytorch/attention/test_linear_mxfp8_attention.py || test_fail "test_linear_mxfp8_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
Expand Down
15 changes: 13 additions & 2 deletions tests/pytorch/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import transformer_engine.pytorch as te

from utils import make_recipe
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV

# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
Expand Down Expand Up @@ -131,8 +132,18 @@ def test_module(self, name: str) -> None:
raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
state_dict = torch.load(checkpoint_file, weights_only=False)

# Update module from checkpoint
module.load_state_dict(state_dict, strict=True)
# Update module from checkpoint. Delayed-scaling legacy extra state is unsafe by
# default and requires an explicit opt-in for trusted compatibility artifacts.
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
if quantization == "fp8":
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1"
try:
module.load_state_dict(state_dict, strict=True)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state


def main() -> None:
Expand Down
13 changes: 12 additions & 1 deletion tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from collections.abc import Iterable, Sequence
import io
import os
import math
import random
from typing import Optional
Expand All @@ -18,6 +19,7 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV

from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
Expand Down Expand Up @@ -3217,7 +3219,16 @@ def test_linear(
)
optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
model_load.load_state_dict(state_dict["model"])
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
if quantization in ("fp8", "fp8_delayed_scaling"):
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1"
try:
model_load.load_state_dict(state_dict["model"])
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state
optim_load.load_state_dict(state_dict["optim"])

# Training steps with loaded model
Expand Down
11 changes: 10 additions & 1 deletion tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
)
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
Expand Down Expand Up @@ -847,7 +848,15 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=

del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
loaded_state_dict = torch.load(path, weights_only=False)
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
try:
block.load_state_dict(loaded_state_dict)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state
Comment on lines +852 to +859

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing env-var set makes the save/restore a no-op

The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.

_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended with FP8 delayed-scaling (a natural step), load_state_dict will raise a RuntimeError because the env var will never be set.

torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)

Expand Down
89 changes: 89 additions & 0 deletions tests/pytorch/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Optional

import pickle

import pytest
import torch
import warnings
Expand Down Expand Up @@ -31,10 +33,19 @@
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import (
CustomRecipe,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
Recipe,
)
from transformer_engine.pytorch._extra_state import (
CheckpointExtraStatePolicy,
UNSAFE_PICKLE_EXTRA_STATE_ENV,
_RECIPE_POLICIES,
should_load_extra_state_pickle,
)

# Check if FP8 is supported
Expand Down Expand Up @@ -691,3 +702,81 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N):
)
new_dequantized_tensor = new_tensor.dequantize()
torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)


def _custom_recipe_qfactory(_role):
return None


def _recipe_subclasses(cls):
for subcls in cls.__subclasses__():
yield subcls
yield from _recipe_subclasses(subcls)


def _pickled_extra_state_payload(recipe_obj, *, include_delayed_state=False):
state = {"recipe": recipe_obj, "extra_fp8_variables": {}}
if include_delayed_state:
state.update(
{
"scale_fwd": torch.ones(1),
"amax_history_fwd": torch.zeros(1, 1),
"scale_bwd": torch.ones(1),
"amax_history_bwd": torch.zeros(1, 1),
}
)
return pickle.dumps(state)


def test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes():
for cls in _recipe_subclasses(Recipe):
key = ("transformer_engine.common.recipe", cls.__name__)
assert key in _RECIPE_POLICIES
assert _RECIPE_POLICIES[key] in CheckpointExtraStatePolicy


@pytest.mark.parametrize(
"recipe_obj",
[
Float8CurrentScaling(),
MXFP8BlockScaling(),
Float8BlockScaling(),
NVFP4BlockScaling(),
],
)
def test_stateless_pickled_extra_state_is_ignored(recipe_obj):
payload = _pickled_extra_state_payload(recipe_obj)
assert not should_load_extra_state_pickle(payload, "test")


def test_stateless_custom_pickled_extra_state_is_ignored():
payload = _pickled_extra_state_payload(CustomRecipe(qfactory=_custom_recipe_qfactory))
assert not should_load_extra_state_pickle(payload, "test")


@pytest.mark.parametrize("payload", [pickle.dumps({}), pickle.dumps({"extra_fp8_variables": {}})])
def test_global_free_pickled_extra_state_is_ignored(payload):
# Older stateless checkpoints serialized an empty dict. Such a payload
# resolves no globals and cannot execute code, so it must load without the
# unsafe opt-in.
assert not should_load_extra_state_pickle(payload, "test")


@pytest.mark.parametrize(
"payload",
[
_pickled_extra_state_payload(DelayedScaling(), include_delayed_state=True),
_pickled_extra_state_payload(
CustomRecipe(qfactory=_custom_recipe_qfactory), include_delayed_state=True
),
pickle.dumps({"scale_inv_fwd": torch.ones(1), "extra_fp8_variables": {}}),
pickle.dumps({"recipe": object(), "extra_fp8_variables": {}}),
b"not a pickle",
],
)
def test_stateful_unknown_or_malformed_pickled_extra_state_requires_opt_in(payload, monkeypatch):
with pytest.raises(RuntimeError, match=UNSAFE_PICKLE_EXTRA_STATE_ENV):
should_load_extra_state_pickle(payload, "test")

monkeypatch.setenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "1")
assert should_load_extra_state_pickle(payload, "test")
Loading
Loading