Commit eb99488

Fix: restore requires_grad in transformers5 reloading (#907)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Patch transformers 5.x parameter loading to preserve the original `requires_grad` settings. In transformers v5.x, loading a checkpoint forcibly sets parameters' `requires_grad`, which unintentionally unfreezes frozen parameters (e.g. the base model in Eagle training). This leads to an optimizer initialization error on resume, since the restored optimizer expects more parameters than the optimizer checkpoint covers. This monkey-patch restores the original `requires_grad` after loading parameters.

Reference: https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640

## Usage

Wrap checkpoint loading in the new `patch_transformers5_params_loading()` context manager, as shown in the `main.py` diff below; a usage sketch follows the `utils.py` diff.

## Testing

Covered by the new `test_resume_training` case in `tests/examples/speculative_decoding/test_eagle.py` (see the diff below), which resumes Eagle3 training from a saved checkpoint.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed model parameter loading in speculative decoding to preserve each parameter's gradient requirement when using HuggingFace Transformers 5.x, ensuring correct behavior during checkpoint resumption and model initialization.

---

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
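To make the failure mode described above concrete, here is a toy sketch of the optimizer-state mismatch (a hypothetical two-module setup, not the library's actual loading code):

```python
import torch

# Hypothetical setup: the base model is frozen, only the draft head trains.
base, head = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
for p in base.parameters():
    p.requires_grad = False

def trainable(*modules):
    return [p for m in modules for p in m.parameters() if p.requires_grad]

# The optimizer whose state was checkpointed covered only the head.
opt = torch.optim.AdamW(trainable(base, head), lr=1e-4)
saved_state = opt.state_dict()  # stands in for the on-disk optimizer state

# If checkpoint loading force-sets requires_grad=True on every parameter,
# the optimizer rebuilt on resume covers base + head ...
for p in base.parameters():
    p.requires_grad = True
resumed_opt = torch.optim.AdamW(trainable(base, head), lr=1e-4)

# ... and restoring the saved state fails: the parameter groups no longer match.
try:
    resumed_opt.load_state_dict(saved_state)
except ValueError as err:
    print(err)
```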
1 parent: 3dd52bf

4 files changed: 72 additions & 6 deletions

examples/speculative_decoding/launch_train.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -134,7 +134,7 @@ OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
 NUM_EPOCHS=${NUM_EPOCHS:-1}
 SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
 LR=${LR:-"1e-4"}
-TRAIN_BS=${TRAIN_BS:-4}
+TRAIN_BS=${TRAIN_BS:-1}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
```

examples/speculative_decoding/main.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -48,7 +48,10 @@

 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
-from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
+from modelopt.torch.speculative.utils import (
+    load_vlm_or_llm_with_kwargs,
+    patch_transformers5_params_loading,
+)
 from modelopt.torch.utils import print_rank_0

 torch.manual_seed(0)
@@ -162,9 +165,10 @@ def train():
     use_offline_training = data_args.offline_data_path is not None

     if checkpoint:
-        _, model = load_vlm_or_llm_with_kwargs(
-            checkpoint, torch_dtype="auto", trust_remote_code=True
-        )
+        with patch_transformers5_params_loading():
+            _, model = load_vlm_or_llm_with_kwargs(
+                checkpoint, torch_dtype="auto", trust_remote_code=True
+            )
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
     else:
         # To avoid OOM for large models, we load and convert model on CPU first.
```
modelopt/torch/speculative/utils.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -485,3 +485,48 @@ def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
         model_cls = transformers.AutoModelForCausalLM

     return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs)
+
+
+@contextlib.contextmanager
+def patch_transformers5_params_loading():
+    """Patch transformers 5.x parameter loading to preserve original `requires_grad` settings.
+
+    In transformers v5.x, loading a checkpoint forcibly sets parameters' `requires_grad`,
+    which may unintentionally unfreeze frozen parameters. This monkey-patch restores the
+    original `requires_grad` after loading parameters.
+
+    Reference:
+    https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640
+    """
+    # Skip patching for non-applicable transformers versions; still yield once
+    # so the generator works as a context manager.
+    if importlib.util.find_spec("transformers.core_model_loading") is None:
+        yield
+        return
+    from transformers import core_model_loading
+
+    if not hasattr(core_model_loading, "set_param_for_module"):
+        yield
+        return
+
+    orig_set_param_for_module = core_model_loading.set_param_for_module
+
+    def patched_set_param_for_module(*args, **kwargs):
+        """Wrap set_param_for_module to restore the original requires_grad."""
+        model, target_name = args[:2]
+        module_path, _, param_name = target_name.rpartition(".")
+        module_obj = model.get_submodule(module_path) if module_path else model
+
+        # Record the requires_grad value before the parameter is replaced.
+        orig_requires_grad = getattr(module_obj, param_name).requires_grad
+
+        # Delegate to the original implementation.
+        orig_set_param_for_module(*args, **kwargs)
+
+        # Restore the recorded requires_grad value on the new parameter.
+        getattr(module_obj, param_name).requires_grad = orig_requires_grad
+
+    try:
+        core_model_loading.set_param_for_module = patched_set_param_for_module
+        yield
+    finally:
+        core_model_loading.set_param_for_module = orig_set_param_for_module
```
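For reference, a minimal usage sketch of the new helper (the checkpoint path is a placeholder; this mirrors the call site in `main.py` above):

```python
import transformers

from modelopt.torch.speculative.utils import patch_transformers5_params_loading

# On transformers 5.x, any checkpoint load inside the block keeps each
# parameter's pre-existing requires_grad flag; on other versions the
# context manager is a no-op.
with patch_transformers5_params_loading():
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "path/to/eagle/checkpoint",  # placeholder checkpoint path
        torch_dtype="auto",
        trust_remote_code=True,
    )
```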

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -89,7 +89,7 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
         "./launch_train.sh",
         "--model", tiny_llama_path,
         "--data", tiny_daring_anteater_path,
-        "--num_epochs", "1",
+        "--num_epochs", "0.25",
         "--lr", "1e-5",
         "--mode", "eagle3",
         "--eagle_config", str(config_file),
@@ -101,6 +101,23 @@
     )


+def test_resume_training(tiny_daring_anteater_path, eagle_output_dir):
+    """Test resume training of Eagle3."""
+    run_example_command(
+        [
+            "./launch_train.sh",
+            "--model", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--data", tiny_daring_anteater_path,
+            "--num_epochs", "0.5",
+            "--lr", "1e-5",
+            "--mode", "eagle3",
+            "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--training_seq_len", "128",  # Match max_position_embeddings
+        ],
+        "speculative_decoding",
+    )
+
+
 def test_ar_validate(eagle_output_dir):
     """Test in-framework AR evaluation."""
     run_example_command(
```
