-
Notifications
You must be signed in to change notification settings - Fork 88
Repeatkv transform #997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Repeatkv transform #997
Changes from all commits
ffcd5a9
0c6100e
b738a8d
1272fcb
b40a34d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,6 +70,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: | |
| self.model = model | ||
| self.config = model.config | ||
| self.hash_params = create_model_params(self, **kwargs) | ||
| self.hash_params["num_kv_heads_repeat"] = kwargs.get("num_kv_heads_repeat", 1) | ||
| self.onnx_path: Optional[str] = None | ||
| self.qpc_path: Optional[str] = None | ||
| self.qpc_session: Optional[QAICInferenceSession] = None | ||
|
|
@@ -440,23 +441,43 @@ def transform( | |
| **compiler_options, | ||
| ): | ||
| # Apply the transformations that are dependent on compilation parameters | ||
| def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: | ||
| """ | ||
| Use the shared wrapped model as transform-tracking root when available. | ||
| This lets encoder/decoder wrappers coordinate one-time transforms. | ||
| """ | ||
| wrapped = getattr(module, "model", None) | ||
| return wrapped if isinstance(wrapped, torch.nn.Module) else module | ||
|
|
||
| qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) | ||
|
|
||
| model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) | ||
| model_config = getattr(self.model, "config", None) or getattr( | ||
| getattr(self.model, "model", None), "config", None | ||
| ) | ||
|
|
||
| if model_config: | ||
| if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): | ||
| if qaic_config: | ||
| if qaic_config.get("blocking_mode", None) == "h": | ||
| qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) | ||
| num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) | ||
| architectures = getattr(model_config, "architectures", None) or [] | ||
| is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures | ||
| if qaic_config: | ||
| if is_deepseek_v3 and (qaic_config.get("blocking_mode", None) == "h"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: for models w/mla and single kv heads, we do not want to replicate, ex: |
||
| qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) | ||
|
quic-dhirajku marked this conversation as resolved.
|
||
| num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) | ||
| transform_root = _transform_tracking_root(self.model) | ||
| applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) | ||
|
|
||
| if ReplicateKVHeadTransform.__name__ in applied_transforms: | ||
| replicate_kv_transformed = False | ||
| logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") | ||
| else: | ||
| self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( | ||
| self.model, num_kv_heads_repeat | ||
| self.model, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| ) | ||
| if replicate_kv_transformed: | ||
| self.hash_params["config"] = self.model.config.to_diff_dict() | ||
|
|
||
| applied_transforms.add(ReplicateKVHeadTransform.__name__) | ||
| setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms) | ||
| if replicate_kv_transformed: | ||
| self.hash_params["config"] = self.model.config.to_diff_dict() | ||
| blocking_config = build_transformer_blocking_config_for_transform( | ||
| model_config, | ||
| ctx_len=ctx_len, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1238,6 +1238,7 @@ def __init__( | |
| self.ccl_enabled = qaic_config.get("ccl_enabled", False) | ||
| self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None | ||
| self.input_shapes, self.output_names = None, None | ||
| # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
| # ---Sampling--- | ||
| # Note: SamplerTransform should be applied after all other transforms | ||
| # are done. The role of the sampler is to just add nodes at the output of the | ||
|
|
@@ -1273,6 +1274,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option | |
|
|
||
| kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) | ||
|
|
||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) | ||
|
|
||
| kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) | ||
|
|
@@ -1281,6 +1283,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option | |
| model, | ||
| pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
| qaic_config=qaic_config, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -1371,7 +1374,12 @@ def export( | |
| if prefill_only and prefill_seq_len > 1: | ||
| offload_pt_weights = False # to keep weight for decode onnx | ||
| else: | ||
| offload_pt_weights = kwargs.get("offload_pt_weights", True) | ||
| num_kv_heads_repeat = ( | ||
| (self.lang_model.model.qaic_config or {}).get("num_kv_heads_repeat", 1) | ||
| if hasattr(self.lang_model.model, "qaic_config") | ||
| else 1 | ||
| ) | ||
| offload_pt_weights = kwargs.get("offload_pt_weights", num_kv_heads_repeat <= 1) | ||
|
|
||
| if not skip_lang: | ||
| self.lang_model.export( | ||
|
|
@@ -2037,6 +2045,7 @@ def __init__( | |
| self.model.config.text_config.use_cache = True | ||
| else: | ||
| self.model.config.use_cache = True | ||
| # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove commented code. |
||
| self.hash_params["qeff_auto_class"] = self.__class__.__name__ | ||
| self.ccl_enabled = False | ||
| if qaic_config: | ||
|
|
@@ -2086,6 +2095,7 @@ def from_pretrained( | |
| config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) | ||
| config._attn_implementation = "eager" | ||
| config.vision_config.use_flash_attn = "false" | ||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) | ||
|
|
||
| kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) | ||
|
|
@@ -2094,6 +2104,7 @@ def from_pretrained( | |
| model, | ||
| pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
| qaic_config=qaic_config, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -2698,6 +2709,7 @@ def from_pretrained( | |
| logger.warning("Updating low_cpu_mem_usage=False") | ||
|
|
||
| kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) | ||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) | ||
|
|
||
| kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) | ||
|
|
@@ -2708,6 +2720,7 @@ def from_pretrained( | |
| continuous_batching=continuous_batching, | ||
| pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
| qaic_config=qaic_config, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -2867,6 +2880,7 @@ def __init__( | |
| setattr(self.model, "mla_absorption", mla_absorption) | ||
| self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None | ||
| self.hash_params["max_seq_len_cached"] = max_seq_len_cached | ||
| # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
|
|
||
| # ---Sampling--- | ||
| # Note: SamplerTransform should be applied after all other transforms | ||
|
|
@@ -2950,6 +2964,7 @@ def from_pretrained( | |
| kv_offload = kwargs.pop("kv_offload", None) | ||
|
|
||
| kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) | ||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) | ||
| if qaic_config is not None: | ||
| qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path | ||
|
|
@@ -2963,6 +2978,7 @@ def from_pretrained( | |
| pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
| qaic_config=qaic_config, | ||
| continuous_batching=continuous_batching, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| **kwargs, | ||
| ) | ||
| return cls( | ||
|
|
@@ -2971,6 +2987,7 @@ def from_pretrained( | |
| qaic_config=qaic_config, | ||
| pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
| max_seq_len_cached=max_seq_len_cached, | ||
| num_kv_heads_repeat=num_kv_heads_repeat, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the lines 459-463, not needed.