examples/speculative_decoding/launch_train.sh (1 addition, 1 deletion)
@@ -134,7 +134,7 @@ OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
NUM_EPOCHS=${NUM_EPOCHS:-1}
SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
LR=${LR:-"1e-4"}
-TRAIN_BS=${TRAIN_BS:-4}
+TRAIN_BS=${TRAIN_BS:-1}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
examples/speculative_decoding/main.py (8 additions, 4 deletions)
@@ -48,7 +48,10 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
-from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
+from modelopt.torch.speculative.utils import (
+    load_vlm_or_llm_with_kwargs,
+    patch_transformers5_params_loading,
+)
from modelopt.torch.utils import print_rank_0

torch.manual_seed(0)
@@ -162,9 +165,10 @@ def train():
use_offline_training = data_args.offline_data_path is not None

if checkpoint:
-        _, model = load_vlm_or_llm_with_kwargs(
-            checkpoint, torch_dtype="auto", trust_remote_code=True
-        )
+        with patch_transformers5_params_loading():
+            _, model = load_vlm_or_llm_with_kwargs(
+                checkpoint, torch_dtype="auto", trust_remote_code=True
+            )
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
else:
# To avoid OOM for large models, we load and convert model on CPU first.
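The wrapper above shows the intended call pattern; the same context manager should apply to any transformers-5.x checkpoint load that must keep frozen parameters frozen. A minimal usage sketch, assuming only the names introduced in this PR; the checkpoint path is a placeholder:

import transformers

from modelopt.torch.speculative.utils import patch_transformers5_params_loading

# While the context manager is active, every parameter loaded by transformers
# keeps the requires_grad value it had before loading.
with patch_transformers5_params_loading():
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "path/to/eagle-checkpoint",  # placeholder checkpoint path
        torch_dtype="auto",
        trust_remote_code=True,
    )
# After the with-block, stock transformers loading behavior is restored.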
modelopt/torch/speculative/utils.py (43 additions, 0 deletions)
@@ -485,3 +485,46 @@ def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
model_cls = transformers.AutoModelForCausalLM

return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs)


+@contextlib.contextmanager
+def patch_transformers5_params_loading():
+    """Patch transformers 5.x parameter loading to preserve original `requires_grad` settings.
+
+    In transformers v5.x, loading a checkpoint overwrites each parameter's `requires_grad`
+    flag, which can unintentionally unfreeze frozen parameters. This monkey patch restores
+    the original `requires_grad` value after each parameter is loaded.
+
+    Reference:
+        https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640
+    """
+    # Skip patching for non-applicable transformers versions. A contextmanager
+    # generator must still yield exactly once, or entering the `with` block
+    # raises RuntimeError.
+    if importlib.util.find_spec("transformers.core_model_loading") is None:
+        yield
+        return
+    from transformers import core_model_loading
+
+    if not hasattr(core_model_loading, "set_param_for_module"):
+        yield
+        return
+
+    orig_set_param_for_module = core_model_loading.set_param_for_module
+
+    def patched_set_param_for_module(*args, **kwargs):
+        """Wrap set_param_for_module so the original requires_grad survives loading."""
+        model, target_name = args[:2]
+        module_path, _, param_name = target_name.rpartition(".")
+        module_obj = model.get_submodule(module_path) if module_path else model
+
+        # Snapshot the requires_grad value before loading overwrites it
+        orig_requires_grad = getattr(module_obj, param_name).requires_grad
+
+        # Load the parameter via the original implementation
+        orig_set_param_for_module(*args, **kwargs)
+
+        # Restore the original requires_grad value
+        getattr(module_obj, param_name).requires_grad = orig_requires_grad
+
+    try:
+        core_model_loading.set_param_for_module = patched_set_param_for_module
+        yield
+    finally:
+        core_model_loading.set_param_for_module = orig_set_param_for_module
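For intuition about the mechanism, here is a self-contained sketch of the same snapshot-and-restore pattern against a stand-in loader, runnable without transformers 5.x; `set_param` and `preserve_requires_grad` are hypothetical names used only for illustration:

import contextlib

import torch
from torch import nn


def set_param(module: nn.Module, name: str, value: torch.Tensor) -> None:
    # Stand-in for a checkpoint loader: replacing a parameter resets
    # requires_grad to the fresh nn.Parameter default (True for float tensors).
    module.register_parameter(name, nn.Parameter(value))


@contextlib.contextmanager
def preserve_requires_grad():
    # Same shape as patch_transformers5_params_loading, applied to set_param.
    global set_param
    orig = set_param

    def patched(module: nn.Module, name: str, value: torch.Tensor) -> None:
        keep = getattr(module, name).requires_grad  # snapshot before overwrite
        orig(module, name, value)
        getattr(module, name).requires_grad = keep  # restore afterwards

    set_param = patched
    try:
        yield
    finally:
        set_param = orig  # always undo the monkey patch


layer = nn.Linear(4, 4)
layer.weight.requires_grad = False  # freeze, as speculative decoding training freezes the base model

with preserve_requires_grad():
    set_param(layer, "weight", torch.zeros(4, 4))

assert layer.weight.requires_grad is False  # still frozen after "loading"

Replacing a parameter creates a fresh nn.Parameter whose requires_grad defaults to True, which is exactly why the snapshot-and-restore step is needed.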
tests/examples/speculative_decoding/test_eagle.py (18 additions, 1 deletion)
@@ -89,7 +89,7 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
"./launch_train.sh",
"--model", tiny_llama_path,
"--data", tiny_daring_anteater_path,
"--num_epochs", "1",
"--num_epochs", "0.25",
"--lr", "1e-5",
"--mode", "eagle3",
"--eagle_config", str(config_file),
@@ -101,6 +101,23 @@
)


+def test_resume_training(tiny_daring_anteater_path, eagle_output_dir):
+    """Test resuming Eagle3 training from a saved checkpoint."""
+    run_example_command(
+        [
+            "./launch_train.sh",
+            "--model", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--data", tiny_daring_anteater_path,
+            "--num_epochs", "0.5",
+            "--lr", "1e-5",
+            "--mode", "eagle3",
+            "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1",
+            "--training_seq_len", "128",  # Match max_position_embeddings
+        ],
+        "speculative_decoding",
+    )
+
+
def test_ar_validate(eagle_output_dir):
"""Test in-framework AR evaluation."""
run_example_command(