Skip to content

Avoid unpickling the extra state when not needed#3123

Open
ptrendx wants to merge 10 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle
Open

Avoid unpickling the extra state when not needed#3123
ptrendx wants to merge 10 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle

Conversation

@ptrendx

@ptrendx ptrendx commented Jun 12, 2026

Copy link
Copy Markdown
Member

Description

Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Avoids unpickling of the stateless recipe extra state
  • Adds a guard and environment variable for the delayed scaling recipes

ptrendx added 2 commits June 12, 2026 05:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR avoids unnecessary unpickling of extra state for stateless FP8 recipes (Float8CurrentScaling, MXFP8BlockScaling, etc.) and adds a security guard requiring NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 before loading pickled extra state from delayed-scaling checkpoints.

  • Introduces transformer_engine/pytorch/_extra_state.py with a pickle-inspection classifier (_classify_extra_state_pickle_impl) that scans pickle opcodes without executing them to determine whether a payload is safe to skip, safe to load with opt-in, or genuinely stateless.
  • Updates get_extra_state/set_extra_state in both module/base.py and ops/op.py to short-circuit for stateless recipes and gate unpickling on the new env-var opt-in; test helpers in test_checkpoint.py and test_fusible_ops.py are updated correctly, while test_numerics.py adds the save/restore scaffolding but omits the actual env-var set (noted in a prior review thread).

Confidence Score: 3/5

The checkpoint save path in ops/op.py writes pickle bytes for any non-stateless recipe, including non-delayed CustomRecipe, but the load path unconditionally drops that data — making the round-trip lossy for custom recipe users.

Both ops/op.py and module/base.py (flagged in a prior review) share the same save/load asymmetry for CustomRecipe without delayed scaling: get_extra_state serializes recipe and extra_fp8_variables, while set_extra_state's classifier returns IGNORE and discards them. Additionally, test_numerics.py never sets the env var it saves and restores, leaving that test helper broken for any future FP8 delayed-scaling extension. Multiple defects on core checkpoint round-trip paths lower confidence.

transformer_engine/pytorch/ops/op.py (get_extra_state/set_extra_state asymmetry for CustomRecipe) and tests/pytorch/test_numerics.py (missing env-var assignment before load_state_dict)

Important Files Changed

Filename Overview
transformer_engine/pytorch/_extra_state.py New module introducing pickle-inspection-based classification logic; the DYNAMIC/CustomRecipe case without delayed state silently drops checkpoint data (IGNORE classification), and _stack_global_args uses a running string list rather than a true stack-aware tracker.
transformer_engine/pytorch/ops/op.py get_extra_state still serializes state for non-delayed CustomRecipe, but set_extra_state silently drops it (IGNORE classification) — same save/load asymmetry as the already-flagged base.py case.
transformer_engine/pytorch/module/base.py Adds stateless-recipe guard in get_extra_state and routes set_extra_state through should_load_extra_state_pickle; the DYNAMIC/non-delayed asymmetry was flagged in a previous review comment.
tests/pytorch/test_recipe.py Good new unit tests for the classifier; the CustomRecipe-without-delayed test asserts IGNORE, which inadvertently documents the silent-drop behavior but does not fail on it.
tests/pytorch/test_numerics.py Adds save/restore scaffolding for UNSAFE_PICKLE_EXTRA_STATE_ENV but never actually sets the env var to "1" before load_state_dict, making the entire block a no-op (flagged in a prior review thread).
tests/pytorch/test_checkpoint.py Correctly gates the env-var opt-in on quantization=="fp8" and properly saves/restores the previous env value.
tests/pytorch/test_fusible_ops.py Correctly enables the env-var opt-in for fp8/fp8_delayed_scaling quantization modes with proper save/restore of the previous value.
qa/L0_pytorch_unittest/test.sh Adds NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 to attention test invocations that checkpoint delayed-scaling FP8 state.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[set_extra_state called] --> B{state is None or empty tensor?}
    B -- Yes --> Z[return early]
    B -- No --> C{isinstance io.BytesIO?}
    C -- Yes --> D{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
    D -- No --> E[raise RuntimeError with advisory]
    D -- Yes --> F[torch.load state]
    C -- No --> G[extract state_bytes]
    G --> H[_classify_extra_state_pickle]
    H --> I{has_global?}
    I -- No --> J[IGNORE: return False]
    I -- Yes --> K{has_recipe_key?}
    K -- No --> L[UNSAFE_LOAD]
    K -- Yes --> M{STATEFUL_FP8_DELAYED_SCALING in policies?}
    M -- Yes --> L
    M -- No --> N{has_delayed_state_keys?}
    N -- Yes --> L
    N -- No --> O{policies empty?}
    O -- Yes --> L
    O -- No --> J
    L --> P{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
    P -- No --> E
    P -- Yes --> Q[pickle.loads state]
    J --> Z
    Q --> R[restore fp8_meta recipe and tensors]
    F --> R
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[set_extra_state called] --> B{state is None or empty tensor?}
    B -- Yes --> Z[return early]
    B -- No --> C{isinstance io.BytesIO?}
    C -- Yes --> D{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
    D -- No --> E[raise RuntimeError with advisory]
    D -- Yes --> F[torch.load state]
    C -- No --> G[extract state_bytes]
    G --> H[_classify_extra_state_pickle]
    H --> I{has_global?}
    I -- No --> J[IGNORE: return False]
    I -- Yes --> K{has_recipe_key?}
    K -- No --> L[UNSAFE_LOAD]
    K -- Yes --> M{STATEFUL_FP8_DELAYED_SCALING in policies?}
    M -- Yes --> L
    M -- No --> N{has_delayed_state_keys?}
    N -- Yes --> L
    N -- No --> O{policies empty?}
    O -- Yes --> L
    O -- No --> J
    L --> P{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
    P -- No --> E
    P -- Yes --> Q[pickle.loads state]
    J --> Z
    Q --> R[restore fp8_meta recipe and tensors]
    F --> R
Loading

Reviews (8): Last reviewed commit: "Fix" | Re-trigger Greptile

Comment thread transformer_engine/pytorch/_extra_state.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.

Comment thread transformer_engine/common/recipe/__init__.py Outdated
Comment thread transformer_engine/pytorch/_extra_state.py Outdated
"""

STATELESS = "stateless"
STATEFUL = "stateful"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.

Suggested change
STATEFUL = "stateful"
STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"

Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment thread transformer_engine/pytorch/_extra_state.py
@ksivaman

Copy link
Copy Markdown
Member

/te-ci pytorch

ksivaman added 2 commits June 23, 2026 09:13
Signed-off-by: ksivamani <ksivamani@nvidia.com>
@ksivaman

Copy link
Copy Markdown
Member

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +852 to +859
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
try:
block.load_state_dict(loaded_state_dict)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing env-var set makes the save/restore a no-op

The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.

_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended with FP8 delayed-scaling (a natural step), load_state_dict will raise a RuntimeError because the env var will never be set.

ksivaman added 2 commits June 24, 2026 20:54
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants