Avoid unpickling the extra state when not needed #3123
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile Summary
This PR avoids unnecessary unpickling of extra state for stateless FP8 recipes (Float8CurrentScaling, MXFP8BlockScaling, etc.) and adds a security guard requiring
Confidence Score: 3/5
The checkpoint save path in ops/op.py writes pickle bytes for any non-stateless recipe, including non-delayed CustomRecipe, but the load path unconditionally drops that data, making the round-trip lossy for custom recipe users.
Both ops/op.py and module/base.py (flagged in a prior review) share the same save/load asymmetry for CustomRecipe without delayed scaling: get_extra_state serializes recipe and extra_fp8_variables, while set_extra_state's classifier returns IGNORE and discards them. Additionally, test_numerics.py never sets the env var it saves and restores, leaving that test helper broken for any future FP8 delayed-scaling extension. Multiple defects on core checkpoint round-trip paths lower confidence.
transformer_engine/pytorch/ops/op.py (get_extra_state/set_extra_state asymmetry for CustomRecipe) and tests/pytorch/test_numerics.py (missing env-var assignment before load_state_dict)
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[set_extra_state called] --> B{state is None or empty tensor?}
B -- Yes --> Z[return early]
B -- No --> C{isinstance io.BytesIO?}
C -- Yes --> D{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
D -- No --> E[raise RuntimeError with advisory]
D -- Yes --> F[torch.load state]
C -- No --> G[extract state_bytes]
G --> H[_classify_extra_state_pickle]
H --> I{has_global?}
I -- No --> J[IGNORE: return False]
I -- Yes --> K{has_recipe_key?}
K -- No --> L[UNSAFE_LOAD]
K -- Yes --> M{STATEFUL_FP8_DELAYED_SCALING in policies?}
M -- Yes --> L
M -- No --> N{has_delayed_state_keys?}
N -- Yes --> L
N -- No --> O{policies empty?}
O -- Yes --> L
O -- No --> J
L --> P{NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1?}
P -- No --> E
P -- Yes --> Q[pickle.loads state]
J --> Z
Q --> R[restore fp8_meta recipe and tensors]
F --> R
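The classification branches in the flowchart (nodes I through O) can be sketched in Python as follows. This is only an illustration of the decision order, not the PR's implementation: the `Decision` enum, the `classify` function, and its parameters are hypothetical stand-ins for whatever `_classify_extra_state_pickle` derives from the serialized bytes, and the flag names are copied from the flowchart without knowing exactly how they are computed.

```python
from enum import Enum


class Decision(Enum):
    """Hypothetical outcome names mirroring flowchart nodes J and L."""

    IGNORE = "ignore"            # skip the payload; nothing is restored
    UNSAFE_LOAD = "unsafe_load"  # payload must be unpickled (needs opt-in)


def classify(has_global, has_recipe_key, policies, has_delayed_state_keys):
    """Walk the flowchart branches I -> K -> M -> N -> O.

    All parameters are illustrative assumptions named after the flowchart:
      has_global: result of the "has_global?" check
      has_recipe_key: result of the "has_recipe_key?" check
      policies: set of policy names found for the payload
      has_delayed_state_keys: result of the "has_delayed_state_keys?" check
    """
    if not has_global:
        return Decision.IGNORE
    if not has_recipe_key:
        return Decision.UNSAFE_LOAD
    if "STATEFUL_FP8_DELAYED_SCALING" in policies:
        return Decision.UNSAFE_LOAD
    if has_delayed_state_keys:
        return Decision.UNSAFE_LOAD
    if not policies:
        # No recognized policy at all: the flowchart treats this as unsafe.
        return Decision.UNSAFE_LOAD
    return Decision.IGNORE
```

Note that `IGNORE` is reached only when nothing in the payload looks stateful, which is also the branch the summary flags as lossy for a non-delayed CustomRecipe.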
Reviews (8): Last reviewed commit: "Fix"
timmoon10 left a comment
Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.
    """

    STATELESS = "stateless"
    STATEFUL = "stateful"
We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.
-    STATEFUL = "stateful"
+    STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"
Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.
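As a rough sketch of the suggested rename, assuming the two values live in an `Enum` (the class name `RecipeStatePolicy` and the helper `requires_unsafe_pickle` are hypothetical and do not come from the PR):

```python
from enum import Enum


class RecipeStatePolicy(Enum):  # hypothetical class name
    """How a recipe's extra state is handled when loading a checkpoint."""

    STATELESS = "stateless"
    # The name spells out that this is the legacy pickle-based path, so a
    # future stateful recipe with safe serialization gets its own value.
    STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"


def requires_unsafe_pickle(policy):
    """Hypothetical helper: only the legacy delayed-scaling path unpickles."""
    return policy is RecipeStatePolicy.STATEFUL_FP8_DELAYED_SCALING
```

With the explicit name, a check like `requires_unsafe_pickle(policy)` cannot be mistaken for "any stateful recipe".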
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: ksivamani <ksivamani@nvidia.com>
/te-ci pytorch
Signed-off-by: ksivamani <ksivamani@nvidia.com>
/te-ci pytorch
    old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
    try:
        block.load_state_dict(loaded_state_dict)
    finally:
        if old_unsafe_extra_state is None:
            os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
        else:
            os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state
Missing env-var set makes the save/restore a no-op
The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.
_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended to use FP8 delayed scaling (a natural next step), load_state_dict will raise a RuntimeError because the env var is never set.
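One way to make the save → set → restore pattern hard to get wrong is to wrap it in a context manager. The sketch below is an illustration only: the helper `allow_unsafe_pickle_extra_state` is hypothetical, the value of `UNSAFE_PICKLE_EXTRA_STATE_ENV` is assumed to be the variable name shown in the PR's flowchart, and unlike the other tests it sets the variable unconditionally rather than conditionally.

```python
import os
from contextlib import contextmanager

# Assumption: the test constant holds the variable name from the PR's flowchart.
UNSAFE_PICKLE_EXTRA_STATE_ENV = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE"


@contextmanager
def allow_unsafe_pickle_extra_state():
    """Hypothetical helper: opt in for the duration of the block, then restore."""
    old_value = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)  # save
    os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1"  # set (the missing step)
    try:
        yield
    finally:
        # Restore the previous state, including "not set at all".
        if old_value is None:
            os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
        else:
            os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_value
```

In the test this would read `with allow_unsafe_pickle_extra_state(): block.load_state_dict(loaded_state_dict)`, which keeps the opt-in scoped to the single call that needs it.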
Description
Avoids unpickling the extra state if the recipe is stateless. Adds a guard that requires the user to explicitly allow loading a checkpoint when unpickling is necessary.
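A minimal sketch of such a guard, assuming the opt-in is the `NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE` environment variable named in the review flowchart. The function name `load_extra_state_guarded` and the error wording are hypothetical; the actual code in the PR may be structured differently.

```python
import os
import pickle

# Variable name as it appears in the PR's flowchart.
ENV_NAME = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE"


def load_extra_state_guarded(state_bytes):
    """Hypothetical helper: unpickle only after an explicit opt-in."""
    if os.environ.get(ENV_NAME) != "1":
        raise RuntimeError(
            f"Refusing to unpickle extra state; set {ENV_NAME}=1 "
            "only if you trust the source of this checkpoint."
        )
    # pickle.loads can execute arbitrary code embedded in untrusted data,
    # which is why it sits behind the opt-in check above.
    return pickle.loads(state_bytes)
```

The guard fails closed: without the variable set to exactly "1", loading raises instead of silently unpickling.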
Type of change
Changes
Please list the changes introduced in this PR: