From 806e43887f5bfdf9f0f19c5af98087dfe8cd0873 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 18 Feb 2026 23:04:54 +0000
Subject: [PATCH 1/3] fix: restore requires_grad in transformers5 reloading

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/main.py |  6 ++++-
 modelopt/torch/speculative/utils.py   | 38 +++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index a880148a7..b57e89a17 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -48,7 +48,10 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
-from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
+from modelopt.torch.speculative.utils import (
+    load_vlm_or_llm_with_kwargs,
+    patch_transformers5_params_loading,
+)
 from modelopt.torch.utils import print_rank_0
 
 torch.manual_seed(0)
 
@@ -162,6 +165,7 @@ def train():
     use_offline_training = data_args.offline_data_path is not None
 
     if checkpoint:
+        patch_transformers5_params_loading()
         _, model = load_vlm_or_llm_with_kwargs(
             checkpoint, torch_dtype="auto", trust_remote_code=True
         )
diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py
index e067641ed..f551b4e9c 100644
--- a/modelopt/torch/speculative/utils.py
+++ b/modelopt/torch/speculative/utils.py
@@ -485,3 +485,41 @@ def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
         model_cls = transformers.AutoModelForCausalLM
 
     return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs)
+
+
+def patch_transformers5_params_loading():
+    """Patch transformers 5.x parameter loading to preserve original `requires_grad` settings.
+
+    In transformers v5.x, loading a checkpoint forcibly sets parameters' `requires_grad`,
+    which may unintentionally unfreeze frozen parameters. This monkey-patch restores the
+    original `requires_grad` after loading parameters.
+
+    Reference:
+    https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640
+    """
+    # Skip patching for non-applicable transformers versions.
+    if importlib.util.find_spec("transformers.core_model_loading") is None:
+        return
+    from transformers import core_model_loading
+
+    if not hasattr(core_model_loading, "set_param_for_module"):
+        return
+
+    orig_set_param_for_module = core_model_loading.set_param_for_module
+
+    def patched_set_param_for_module(*args, **kwargs):
+        """Wrapper around set_param_for_module that restores the original requires_grad."""
+        model, target_name = args[:2]
+        module_path, _, param_name = target_name.rpartition(".")
+        module_obj = model.get_submodule(module_path) if module_path else model
+
+        # Get original requires_grad value
+        orig_requires_grad = getattr(module_obj, param_name).requires_grad
+
+        # Call set_param_for_module
+        orig_set_param_for_module(*args, **kwargs)
+
+        # Restore original requires_grad value
+        getattr(module_obj, param_name).requires_grad = orig_requires_grad
+
+    core_model_loading.set_param_for_module = patched_set_param_for_module
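[Note: a minimal, self-contained sketch of the requires_grad-restoring pattern
this patch installs. It uses plain torch; `force_set_param` is a hypothetical
stand-in for transformers' set_param_for_module, not a real transformers API.]

    import torch
    import torch.nn as nn

    def force_set_param(module: nn.Module, name: str, value: torch.Tensor) -> None:
        # Hypothetical loader step: like transformers 5.x, it replaces the
        # parameter and leaves requires_grad at nn.Parameter's default (True).
        setattr(module, name, nn.Parameter(value))

    def set_param_preserving_grad(module: nn.Module, name: str, value: torch.Tensor) -> None:
        # Remember whether the parameter was frozen before loading ...
        orig_requires_grad = getattr(module, name).requires_grad
        force_set_param(module, name, value)
        # ... and restore that state afterwards, as the patched wrapper does.
        getattr(module, name).requires_grad = orig_requires_grad

    lin = nn.Linear(2, 2)
    lin.weight.requires_grad_(False)  # freeze, as EAGLE training freezes the base model
    set_param_preserving_grad(lin, "weight", torch.zeros(2, 2))
    assert lin.weight.requires_grad is False  # still frozen after "loading"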
From 836db822ffa3116a68fc160df39335fa25b60146 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 18 Feb 2026 23:37:36 +0000
Subject: [PATCH 2/3] refactor: make patch_transformers5_params_loading a
 context manager (CodeRabbit review)

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/main.py | 8 ++++----
 modelopt/torch/speculative/utils.py   | 7 ++++++-
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index b57e89a17..682111184 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -165,10 +165,10 @@ def train():
     use_offline_training = data_args.offline_data_path is not None
 
     if checkpoint:
-        patch_transformers5_params_loading()
-        _, model = load_vlm_or_llm_with_kwargs(
-            checkpoint, torch_dtype="auto", trust_remote_code=True
-        )
+        with patch_transformers5_params_loading():
+            _, model = load_vlm_or_llm_with_kwargs(
+                checkpoint, torch_dtype="auto", trust_remote_code=True
+            )
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
     else:
         # To avoid OOM for large models, we load and convert model on CPU first.
diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py
index f551b4e9c..e34538665 100644
--- a/modelopt/torch/speculative/utils.py
+++ b/modelopt/torch/speculative/utils.py
@@ -487,6 +487,7 @@ def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
     return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs)
 
 
+@contextlib.contextmanager
 def patch_transformers5_params_loading():
     """Patch transformers 5.x parameter loading to preserve original `requires_grad` settings.
 
@@ -522,4 +523,8 @@ def patched_set_param_for_module(*args, **kwargs):
         # Restore original requires_grad value
         getattr(module_obj, param_name).requires_grad = orig_requires_grad
 
-    core_model_loading.set_param_for_module = patched_set_param_for_module
+    try:
+        core_model_loading.set_param_for_module = patched_set_param_for_module
+        yield
+    finally:
+        core_model_loading.set_param_for_module = orig_set_param_for_module
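[Note: hypothetical usage of the context-manager form introduced above. The
monkey-patch is active only inside the `with` block, and the try/finally
guarantees the stock set_param_for_module is restored even if loading raises.
The checkpoint path is illustrative, not part of this series.]

    from modelopt.torch.speculative.utils import (
        load_vlm_or_llm_with_kwargs,
        patch_transformers5_params_loading,
    )

    with patch_transformers5_params_loading():
        _, model = load_vlm_or_llm_with_kwargs(
            "ckpts/eagle-tinyllama-cp1",  # illustrative checkpoint path
            torch_dtype="auto",
            trust_remote_code=True,
        )
    # Outside the block, transformers' loading behavior is back to stock.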
From 6b5d205e2ef070bdc662f39165f704b2ab1247f6 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 18 Feb 2026 23:48:41 +0000
Subject: [PATCH 3/3] test: resume training

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/launch_train.sh |  2 +-
 .../speculative_decoding/test_eagle.py        | 19 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh
index c0b9ea00e..ae8a21eea 100755
--- a/examples/speculative_decoding/launch_train.sh
+++ b/examples/speculative_decoding/launch_train.sh
@@ -134,7 +134,7 @@ OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
 NUM_EPOCHS=${NUM_EPOCHS:-1}
 SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
 LR=${LR:-"1e-4"}
-TRAIN_BS=${TRAIN_BS:-4}
+TRAIN_BS=${TRAIN_BS:-1}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index 3cbbc69c8..4f80692ca 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -89,7 +89,7 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
             "./launch_train.sh",
             "--model", tiny_llama_path,
             "--data", tiny_daring_anteater_path,
-            "--num_epochs", "1",
+            "--num_epochs", "0.25",
             "--lr", "1e-5",
             "--mode", "eagle3",
             "--eagle_config", str(config_file),
@@ -101,6 +101,23 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
     )
 
 
+def test_resume_training(tiny_daring_anteater_path, eagle_output_dir):
+    """Test resuming Eagle3 training from a saved checkpoint."""
+    run_example_command(
+        [
+            "./launch_train.sh",
+            "--model", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--data", tiny_daring_anteater_path,
+            "--num_epochs", "0.5",
+            "--lr", "1e-5",
+            "--mode", "eagle3",
+            "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--training_seq_len", "128",  # Match max_position_embeddings
+        ],
+        "speculative_decoding",
+    )
+
+
 def test_ar_validate(eagle_output_dir):
     """Test in-framework AR evaluation."""
     run_example_command(
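[Note: a sketch of the property the resume test above exercises. After
reloading a checkpoint under the patch, base-model parameters that were frozen
for EAGLE training should still report requires_grad == False. The checkpoint
path and the assumption that some parameters are frozen are illustrative, not
taken from this series.]

    from modelopt.torch.speculative.utils import (
        load_vlm_or_llm_with_kwargs,
        patch_transformers5_params_loading,
    )

    with patch_transformers5_params_loading():
        _, model = load_vlm_or_llm_with_kwargs(
            "ckpts/eagle-tinyllama-cp1", torch_dtype="auto", trust_remote_code=True
        )

    # Without the patch, transformers 5.x reloading could silently flip these to True.
    frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
    assert frozen, "expected base-model parameters to remain frozen after reload"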