diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index d7aadf994..c35feffe3 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -51,8 +51,10 @@ import modelopt.torch.sparsity as mts
 from modelopt.torch.export import (
     export_hf_checkpoint,
+    export_speculative_decoding,
     export_tensorrt_llm_checkpoint,
     get_model_type,
+    has_spec_opt,
     save_expert_token_count_table,
 )
 from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
@@ -566,6 +568,13 @@ def export_quantized(
 
     export_path = args.export_path
 
+    # Early exit for speculative decoding checkpoints
+    # No tokenizer saving needed for spec ckpts
+    if has_spec_opt(full_model):
+        export_speculative_decoding(full_model, export_dir=export_path)
+        print(f"Quantized speculative decoding checkpoint exported to: {export_path}")
+        return
+
     # Check if the model is a multimodal/VLM model
     is_vlm = is_multimodal_model(full_model)
 
diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py
index fc3421583..23a7560f7 100644
--- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py
+++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py
@@ -20,7 +20,7 @@ import torch
 
 import modelopt.torch.opt as mto
-from modelopt.torch.export import export_hf_checkpoint
+from modelopt.torch.export import export_speculative_decoding
 from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
 
 
@@ -41,7 +41,7 @@ def parse_args():
     _, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto")
     model.eval()
     with torch.inference_mode():
-        export_hf_checkpoint(
+        export_speculative_decoding(
             model,
             export_dir=args.export_path,
         )
diff --git a/modelopt/torch/export/plugins/hf_spec_configs.py b/modelopt/torch/export/plugins/hf_spec_configs.py
new file mode 100644
index 000000000..b78dfadd4
--- /dev/null
+++ b/modelopt/torch/export/plugins/hf_spec_configs.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Template config for speculative decoding exporting.""" + +llama_eagle_template_config = { + "architectures": ["LlamaForCausalLMEagle3"], + "bos_token_id": None, + "eos_token_id": None, + "hidden_act": None, + "hidden_size": None, + "initializer_range": None, + "intermediate_size": None, + "max_position_embeddings": None, + "model_type": "llama", + "num_attention_heads": None, + "num_key_value_heads": None, + "num_hidden_layers": None, + "pad_token_id": None, + "rms_norm_eps": None, + "tie_word_embeddings": False, + "torch_dtype": None, + "transformers_version": None, + "use_cache": None, + "vocab_size": None, + "draft_vocab_size": None, + "rope_scaling": None, + "attention_bias": None, + "attention_dropout": None, + "head_dim": None, + "mlp_bias": None, + "pretraining_tp": None, + "rope_theta": None, + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": None, + "use_aux_hidden_state": None, + "use_input_layernorm_in_first_layer": None, + "use_last_layernorm": None, + "use_mtp_layernorm": None, + "next_layer_regular": True, + "parallel_draft_step": None, + "parallel_draft_heads_num_layers": None, + }, +} + +kimik2_eagle_template_config = { + "architectures": ["Eagle3DeepseekV2ForCausalLM"], + "attention_bias": None, + "attention_dropout": None, + "aux_loss_alpha": None, + "bos_token_id": None, + "chunk_size_feed_forward": None, + "diversity_penalty": None, + "do_sample": None, + "early_stopping": None, + "encoder_no_repeat_ngram_size": None, + "eos_token_id": None, + "ep_size": None, + "first_k_dense_replace": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": None, + "hidden_size": None, + "id2label": None, + "initializer_range": None, + "intermediate_size": None, + "is_decoder": None, + "is_encoder_decoder": None, + "kv_lora_rank": None, + "label2id": None, + "length_penalty": None, + "max_length": None, + "max_position_embeddings": None, + "min_length": None, + "model_type": "kimi_k2", + "moe_intermediate_size": None, + "moe_layer_freq": None, + "n_group": None, + "n_routed_experts": None, + "n_shared_experts": None, + "no_repeat_ngram_size": None, + "norm_topk_prob": None, + "num_attention_heads": None, + "num_beam_groups": None, + "num_beams": None, + "num_experts_per_tok": None, + "num_hidden_layers": None, + "num_key_value_heads": None, + "num_nextn_predict_layers": None, + "num_return_sequences": None, + "output_attentions": None, + "output_hidden_states": None, + "output_scores": None, + "pad_token_id": None, + "pretraining_tp": None, + "pruned_heads": None, + "q_lora_rank": None, + "qk_nope_head_dim": None, + "qk_rope_head_dim": None, + "remove_invalid_values": None, + "repetition_penalty": None, + "return_dict": None, + "return_dict_in_generate": None, + "rms_norm_eps": None, + "rope_scaling": None, + "rope_theta": None, + "routed_scaling_factor": None, + "scoring_func": None, + "sep_token_id": None, + "seq_aux": None, + "temperature": None, + "tf_legacy_loss": None, + "tie_encoder_decoder": None, + "tie_word_embeddings": None, + "top_k": None, + "top_p": None, + "topk_group": None, + "topk_method": None, + "torch_dtype": None, + "torchscript": None, + "transformers_version": None, + "typical_p": None, + "use_bfloat16": None, + "use_cache": None, + "v_head_dim": None, + "vocab_size": None, + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": None, + "use_aux_hidden_state": None, + "use_input_layernorm_in_first_layer": None, + "use_last_layernorm": None, + "use_mtp_layernorm": None, + "next_layer_regular": True, + 
"parallel_draft_step": None, + "parallel_draft_heads_num_layers": None, + }, +} diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index cdb009003..d287b7474 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -16,122 +16,191 @@ """Modify state_dict and config for exporting speculative decoding in official format.""" import re -from copy import copy +from copy import deepcopy import torch import torch.nn as nn +from .hf_spec_configs import kimik2_eagle_template_config, llama_eagle_template_config + +ALL_SPEC_MODES = ["eagle"] + LLAMA_EAGLE_SINGLE_LAYER = { "required": { - "midlayer.self_attn.q_proj.weight", - "midlayer.self_attn.k_proj.weight", - "midlayer.self_attn.v_proj.weight", - "midlayer.self_attn.o_proj.weight", - "midlayer.mlp.gate_proj.weight", - "midlayer.mlp.up_proj.weight", - "midlayer.mlp.down_proj.weight", - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "midlayer.post_attention_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.self_attn.q_proj", + "layers.0.self_attn.k_proj", + "layers.0.self_attn.v_proj", + "layers.0.self_attn.o_proj", + "layers.0.mlp.gate_proj", + "layers.0.mlp.up_proj", + "layers.0.mlp.down_proj", + "layers.0.hidden_norm", + "layers.0.input_layernorm", + "layers.0.post_attention_layernorm", + "norm", + "fc", }, - "optional": {"d2t", "lm_head.weight"}, + "optional": {"d2t", "lm_head"}, } KIMIK2_EAGLE_SINGLE_LAYER = { "required": { - "midlayer.self_attn.kv_a_layernorm.weight", - "midlayer.self_attn.q_a_layernorm.weight", - "midlayer.self_attn.q_a_proj.weight", - "midlayer.self_attn.q_b_proj.weight", - "midlayer.self_attn.kv_a_proj_with_mqa.weight", - "midlayer.self_attn.kv_b_proj.weight", - "midlayer.self_attn.o_proj.weight", - "midlayer.mlp.gate_proj.weight", - "midlayer.mlp.up_proj.weight", - "midlayer.mlp.down_proj.weight", - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "midlayer.post_attention_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.self_attn.kv_a_layernorm", + "layers.0.self_attn.q_a_layernorm", + "layers.0.self_attn.q_a_proj", + "layers.0.self_attn.q_b_proj", + "layers.0.self_attn.kv_a_proj_with_mqa", + "layers.0.self_attn.kv_b_proj", + "layers.0.self_attn.o_proj", + "layers.0.mlp.gate_proj", + "layers.0.mlp.up_proj", + "layers.0.mlp.down_proj", + "layers.0.hidden_norm", + "layers.0.input_layernorm", + "layers.0.post_attention_layernorm", + "norm", + "fc", }, "optional": { "d2t", - "lm_head.weight", + "lm_head", }, } -def _check_valid_sd(state_dict: dict, eagle_decoder_type: str, num_hidden_layers: int): - """Check the export state dict is valid, otherwise raise Exception.""" - expected_keys_single_layer = { - "llama": LLAMA_EAGLE_SINGLE_LAYER, - "kimik2": KIMIK2_EAGLE_SINGLE_LAYER, - }[eagle_decoder_type] - # Check that export sd has required keys - if num_hidden_layers == 1: - for key in expected_keys_single_layer["required"]: - assert key in state_dict, f"Missing required key: {key}" - else: +def has_spec_opt(model: nn.Module): + """Check if the model has speculative decoding optimization.""" + opt_modes = getattr(model, "_modelopt_state", []) + return any(mode[0] in ALL_SPEC_MODES for mode in opt_modes) + + +def has_quant_opt(model: nn.Module): + """Check if the model has quantization optimization.""" + opt_modes = getattr(model, "_modelopt_state", []) + return any(mode[0] == "quantize" for mode in opt_modes) + + +class EagleExporter: + 
"""Draft model exporter for Eagle.""" + + def __init__(self, model: nn.Module, dtype: torch.dtype | None = None): + """Initialize the EagleExporter.""" + self.model = model + self.eagle_decoder_type = model.eagle_config.eagle_decoder_type + self.num_hidden_layers = model.eagle_config.num_hidden_layers + if has_quant_opt(model): + from ..unified_export_hf import _export_transformers_checkpoint + + self.state_dict, self.hf_quant_config = _export_transformers_checkpoint(model, dtype) + else: + self.state_dict, self.hf_quant_config = model.state_dict(), None + + def _check_valid_sd(self, export_sd: dict): + """Check the export state dict is valid, otherwise raise Exception.""" + expected_keys_single_layer = { + "llama": LLAMA_EAGLE_SINGLE_LAYER, + "kimik2": KIMIK2_EAGLE_SINGLE_LAYER, + }[self.eagle_decoder_type] + # Check that export sd has required keys for key in expected_keys_single_layer["required"]: - assert key.replace("midlayer", "midlayer.0") in state_dict, ( - f"Missing required key: {key}" - ) - for i in range(1, num_hidden_layers): + assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight" + for i in range(1, self.num_hidden_layers): for key in expected_keys_single_layer["required"] - { - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.hidden_norm", + "layers.0.input_layernorm", + "norm", + "fc", }: - assert key.replace("midlayer", f"midlayer.{i}") in state_dict, ( - f"Missing required key: {key}" + assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, ( + f"Missing required key: {key}.weight" ) - # Check that export sd has no unexpected keys - allowed_keys_single_layer = ( - expected_keys_single_layer["required"] | expected_keys_single_layer["optional"] - ) - if num_hidden_layers == 1: - for key in state_dict: - assert key in allowed_keys_single_layer, f"Unexpected key: {key}" - else: - for key in state_dict: - assert re.sub(r"midlayers\.\d+\.", "", key) in { - k.replace("midlayer.", "") for k in allowed_keys_single_layer - }, f"Unexpected key: {key}" - - -def spec_opt_only(model: nn.Module): - """Check if the model have only speculative decoding optimization.""" - opt_modes = getattr(model, "_modelopt_state", None) - return ( - isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle" - ) - - -def export_spec_ckpt_state_dict(model: nn.Module): - """Only return the state dict of the draft model in official format and ignore the base model.""" - # check the model has only speculative decoding - assert spec_opt_only(model), "Not purely eagle model." 
-
-    # Rename layers to midlayer
-    if model.eagle_config.num_hidden_layers == 1:
-        model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")[0]
-    else:
-        model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")
-    export_sd = copy(model.eagle_module.state_dict())
-
-    # Use base model's lm head if draft model doesn't have one
-    if "lm_head.weight" not in export_sd:
-        export_sd["lm_head.weight"] = model.state_dict()["lm_head.weight"]
-
-    # Rename parallel draft weights
-    if model.eagle_config.parallel_draft_step > 1:
-        for i in range(model.eagle_config.parallel_draft_step - 1):
-            for j in range(model.eagle_config.parallel_draft_heads_num_layers):
+        # Check that export sd has no unexpected keys
+        # Note that quantized Eagle models are allowed to have scales
+        allowed_keys_single_layer = (
+            expected_keys_single_layer["required"] | expected_keys_single_layer["optional"]
+        )
+        for key in export_sd:
+            assert (
+                re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
+                in allowed_keys_single_layer
+            ), f"Unexpected key: {key}"
+
+    def extract_state_dict(self):
+        """Extract the state dict of the draft model in deployment format."""
+        export_sd = {}
+        for key in self.state_dict:
+            if "eagle_module" in key or "lm_head" in key:
+                export_key = key.replace("eagle_module.", "")
+                export_sd[export_key] = self.state_dict[key].clone()
+        # Use base model's lm head if draft model doesn't have one
+        if "lm_head.weight" not in export_sd:
+            export_sd["lm_head.weight"] = self.state_dict["lm_head.weight"]
+
+        self._check_valid_sd(export_sd)
+
+        return export_sd
+
+    def export_config(self, model):
+        """Export config.json in deployment format."""
+        template_config: dict = {
+            "llama": llama_eagle_template_config,
+            "kimik2": kimik2_eagle_template_config,
+        }[model.eagle_config.eagle_decoder_type]
+        template_config = deepcopy(template_config)
+
+        def _get_config_from_draft_or_base(key: str, model: nn.Module):
+            if getattr(model._draft_model_config, key, None) is not None:
+                return getattr(model._draft_model_config, key)
+            elif getattr(model.config, key, None) is not None:
+                return getattr(model.config, key)
+            else:
+                return None
+
+        for key in template_config:
+            value = template_config[key]
+            if isinstance(value, dict):
+                # for eagle config, we find it in model.eagle_config
+                for sub_key in value:
+                    if value[sub_key] is None:
+                        value[sub_key] = _get_config_from_draft_or_base(sub_key, model)
+            elif value is None:
+                # First, we try to load from eagle config.
+                new_value = _get_config_from_draft_or_base(key, model)
+                # If the value is a torch.dtype, we convert to string for serialization.
+                if isinstance(new_value, torch.dtype):
+                    new_value = str(new_value).replace("torch.", "")
+                template_config[key] = new_value
+
+        if self.hf_quant_config is not None:
+            template_config["quantization_config"] = self.hf_quant_config
+
+        return template_config
+
+    def export_quant_config(self):
+        """Export hf_quant_config.json."""
+        return deepcopy(self.hf_quant_config)
+
+
+class EagleMedusaExporter(EagleExporter):
+    """Draft model exporter for EagleMedusa."""
+
+    def __init__(self, model: nn.Module, dtype: torch.dtype | None = None):
+        """Initialize the EagleMedusaExporter."""
+        super().__init__(model, dtype)
+        self.parallel_draft_step = model.eagle_config.parallel_draft_step
+        self.parallel_draft_heads_num_layers = model.eagle_config.parallel_draft_heads_num_layers
+        # NOTE: tmp: bypassing format check for parallel draft
+        self._check_valid_sd = lambda *args, **kwargs: None
+
+    def extract_state_dict(self):
+        """Extract the state dict of the draft model in deployment format."""
+        export_sd = super().extract_state_dict()
+        if self.parallel_draft_step <= 1:
+            return export_sd
+
+        for i in range(self.parallel_draft_step - 1):
+            for j in range(self.parallel_draft_heads_num_layers):
                 export_sd[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.weight"] = (
                     export_sd.pop(f"parallel_draft_heads.medusa_heads.{i}.{j}.linear.weight")
                 )
@@ -143,180 +212,4 @@
         export_sd["parallel_draft_heads.lm_head.weight"] = export_sd.pop(
             "parallel_draft_heads.lm_head.weight"
         )
-        # NOTE: tmp: bypassing format check for parallel draft
         return export_sd
-
-    _check_valid_sd(
-        export_sd, model.eagle_config.eagle_decoder_type, model.eagle_config.num_hidden_layers
-    )
-
-    return export_sd
-
-
-def export_spec_ckpt_config(model: nn.Module):
-    """Return the config of draft model in official format."""
-    assert spec_opt_only(model), "Not purely eagle model."
-
-    # This is the config keys in official checkpoint.
-    llama_eagle_template_config = {
-        "architectures": ["LlamaForCausalLMEagle3"],
-        "bos_token_id": None,
-        "eos_token_id": None,
-        "hidden_act": None,
-        "hidden_size": None,
-        "initializer_range": None,
-        "intermediate_size": None,
-        "max_position_embeddings": None,
-        "model_type": "llama",
-        "num_attention_heads": None,
-        "num_key_value_heads": None,
-        "num_hidden_layers": None,
-        "pad_token_id": None,
-        "rms_norm_eps": None,
-        "tie_word_embeddings": False,
-        "torch_dtype": None,
-        "transformers_version": None,
-        "use_cache": None,
-        "vocab_size": None,
-        "draft_vocab_size": None,
-        "rope_scaling": None,
-        "attention_bias": None,
-        "attention_dropout": None,
-        "head_dim": None,
-        "mlp_bias": None,
-        "pretraining_tp": None,
-        "rope_theta": None,
-        "eagle_config": {
-            "eagle_aux_hidden_state_layer_ids": None,
-            "use_aux_hidden_state": None,
-            "use_input_layernorm_in_first_layer": None,
-            "use_last_layernorm": None,
-            "use_mtp_layernorm": None,
-            "next_layer_regular": True,
-            "parallel_draft_step": None,
-            "parallel_draft_heads_num_layers": None,
-        },
-    }
-
-    kimik2_eagle_template_config = {
-        "architectures": ["Eagle3DeepseekV2ForCausalLM"],
-        "attention_bias": None,
-        "attention_dropout": None,
-        "aux_loss_alpha": None,
-        "bos_token_id": None,
-        "chunk_size_feed_forward": None,
-        "diversity_penalty": None,
-        "do_sample": None,
-        "early_stopping": None,
-        "encoder_no_repeat_ngram_size": None,
-        "eos_token_id": None,
-        "ep_size": None,
-        "first_k_dense_replace": None,
-        "forced_bos_token_id": None,
-        "forced_eos_token_id": None,
-        "hidden_act": None,
-        "hidden_size": None,
-        "id2label": None,
-        "initializer_range": None,
-        "intermediate_size": None,
-        "is_decoder": None,
-        "is_encoder_decoder": None,
-        "kv_lora_rank": None,
-        "label2id": None,
-        "length_penalty": None,
-        "max_length": None,
-        "max_position_embeddings": None,
-        "min_length": None,
-        "model_type": "kimi_k2",
-        "moe_intermediate_size": None,
-        "moe_layer_freq": None,
-        "n_group": None,
-        "n_routed_experts": None,
-        "n_shared_experts": None,
-        "no_repeat_ngram_size": None,
-        "norm_topk_prob": None,
-        "num_attention_heads": None,
-        "num_beam_groups": None,
-        "num_beams": None,
-        "num_experts_per_tok": None,
-        "num_hidden_layers": None,
-        "num_key_value_heads": None,
-        "num_nextn_predict_layers": None,
-        "num_return_sequences": None,
-        "output_attentions": None,
-        "output_hidden_states": None,
-        "output_scores": None,
-        "pad_token_id": None,
-        "pretraining_tp": None,
-        "pruned_heads": None,
-        "q_lora_rank": None,
-        "qk_nope_head_dim": None,
-        "qk_rope_head_dim": None,
-        "remove_invalid_values": None,
-        "repetition_penalty": None,
-        "return_dict": None,
-        "return_dict_in_generate": None,
-        "rms_norm_eps": None,
-        "rope_scaling": None,
-        "rope_theta": None,
-        "routed_scaling_factor": None,
-        "scoring_func": None,
-        "sep_token_id": None,
-        "seq_aux": None,
-        "temperature": None,
-        "tf_legacy_loss": None,
-        "tie_encoder_decoder": None,
-        "tie_word_embeddings": None,
-        "top_k": None,
-        "top_p": None,
-        "topk_group": None,
-        "topk_method": None,
-        "torch_dtype": None,
-        "torchscript": None,
-        "transformers_version": None,
-        "typical_p": None,
-        "use_bfloat16": None,
-        "use_cache": None,
-        "v_head_dim": None,
-        "vocab_size": None,
-        "eagle_config": {
-            "eagle_aux_hidden_state_layer_ids": None,
-            "use_aux_hidden_state": None,
-            "use_input_layernorm_in_first_layer": None,
-            "use_last_layernorm": None,
-            "use_mtp_layernorm": None,
-            "next_layer_regular": True,
-            "parallel_draft_step": None,
-            "parallel_draft_heads_num_layers": None,
-        },
-    }
-
-    template_config: dict = {
-        "llama": llama_eagle_template_config,
-        "kimik2": kimik2_eagle_template_config,
-    }[model.eagle_config.eagle_decoder_type]
-
-    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
-        if getattr(model.eagle_config, key, None) is not None:
-            return getattr(model.eagle_config, key)
-        elif getattr(model.config, key, None) is not None:
-            return getattr(model.config, key)
-        else:
-            return None
-
-    for key in template_config:
-        value = template_config[key]
-        if isinstance(value, dict):
-            # for eagle config, we find it in model.eagle_config
-            for sub_key in value:
-                if value[sub_key] is None:
-                    value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
-        elif value is None:
-            # First, we try to load fron eagle config.
-            new_value = _get_config_from_eagle_config_or_base_config(key, model)
-            # If the value is a torch.dtype, we convert to string for serialization.
-            if isinstance(new_value, torch.dtype):
-                new_value = str(new_value).replace("torch.", "")
-            template_config[key] = new_value
-
-    return template_config
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ca80cb450..352c347c8 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -81,7 +81,7 @@
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
-from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
+from .plugins import has_spec_opt
 from .quant_utils import (
     fuse_prequant_layernorm,
     fuse_prequant_to_linear,
@@ -98,7 +98,7 @@
     to_quantized_weight,
 )
 
-__all__ = ["export_hf_checkpoint"]
+__all__ = ["export_hf_checkpoint", "export_speculative_decoding"]
 
 
 def _is_enabled_quantizer(quantizer):
@@ -978,6 +978,35 @@ def _export_diffusers_checkpoint(
     print(f"Export complete. Saved to: {export_dir}")
 
 
+def export_speculative_decoding(
+    model: torch.nn.Module,
+    dtype: torch.dtype | None = None,
+    export_dir: Path | str = tempfile.gettempdir(),
+) -> None:
+    """Export speculative decoding HuggingFace model checkpoint."""
+    assert has_spec_opt(model), "Model is not optimized for speculative decoding."
+
+    export_dir = Path(export_dir)
+    export_dir.mkdir(parents=True, exist_ok=True)
+
+    exporter = model.get_exporter(dtype)
+
+    # Export state_dict
+    drafter_sd = exporter.extract_state_dict()
+    save_file(drafter_sd, f"{export_dir}/model.safetensors")
+
+    # Export config.json
+    drafter_config = exporter.export_config(model)
+    with open(f"{export_dir}/config.json", "w") as file:
+        json.dump(drafter_config, file, indent=4)
+
+    # Save hf_quant_config.json for backward compatibility
+    hf_quant_config = exporter.export_quant_config()
+    if hf_quant_config:
+        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)
+
+
 def export_hf_checkpoint(
     model: Any,
     dtype: torch.dtype | None = None,
@@ -1013,15 +1042,6 @@
     export_dir.mkdir(parents=True, exist_ok=True)
 
-    # Transformers model export
-    # NOTE: (hg) Early exit for speculative decoding models
-    # This is a temp workaround to avoid error with offline spec ckpt during export
-    if spec_opt_only(model):
-        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-        with open(f"{export_dir}/config.json", "w") as file:
-            json.dump(export_spec_ckpt_config(model), file, indent=4)
-        return
-
     try:
         post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
 
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index f8b7e33df..23d0254e8 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -50,6 +50,7 @@
 from transformers.utils import ModelOutput
 from transformers.utils.quantization_config import QuantizationMethod
 
+from ...export.plugins.hf_spec_export import EagleExporter, EagleMedusaExporter
 from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
 from ..eagle.utils import expand_mask, make_causal_mask
@@ -450,6 +451,18 @@ def _base_llm_config(self):
             or self.config
         )
 
+    @property
+    def _draft_model_config(self):
+        """Return the llm config for the draft model."""
+        return self.eagle_config
+
+    def get_exporter(self, dtype: torch.dtype | None = None):
+        """Get the exporter for the draft model."""
+        exporter_cls = (
+            EagleExporter if self.eagle_config.parallel_draft_step <= 1 else EagleMedusaExporter
+        )
+        return exporter_cls(self, dtype)
+
     def _find_base_model_parts(self):
         """Find model parts from different models and set base_{part}_path attributes."""
         base_model_parts_mapping = {
diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index 4f80692ca..9c73ea96a 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -145,7 +145,7 @@ def test_export_hf_checkpoint(eagle_output_dir):
     # Check the exported checkpoints have required keys
     state_dict = safetensors.torch.load_file(eagle_output_dir / "eagle-tinyllama-export" / "model.safetensors")
     for required_key in LLAMA_EAGLE_SINGLE_LAYER["required"]:
-        assert required_key in state_dict, f"Missing key '{required_key}' in state_dict"
+        assert f"{required_key}.weight" in state_dict, f"Missing key '{required_key}.weight' in state_dict"
 
 
 def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir):
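
For reference, a minimal usage sketch of the new export entry point. It mirrors the updated examples/speculative_decoding/scripts/export_hf_checkpoint.py and the has_spec_opt guard added to hf_ptq.py; the checkpoint paths below are hypothetical.

import torch

from modelopt.torch.export import export_speculative_decoding, has_spec_opt
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs

# Load a model that already carries an Eagle ModelOpt state (hypothetical path).
_, model = load_vlm_or_llm_with_kwargs("/ckpts/eagle-tinyllama", torch_dtype="auto")
model.eval()

with torch.inference_mode():
    # Guard mirrors the early exit added to export_quantized() in hf_ptq.py.
    if has_spec_opt(model):
        export_speculative_decoding(model, export_dir="/ckpts/eagle-tinyllama-export")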