From d37ef9dadfe13352cb4b98301f4a7d6a0341c103 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 28 Jan 2026 07:37:28 +0000 Subject: [PATCH 01/51] added modeling for kimik2 Signed-off-by: Onkar Chougule --- .../models/deepseek_v3/__init__.py | 7 + .../deepseek_v3/configuration_deepseek.py | 212 ++ .../models/deepseek_v3/modeling_deepseek.py | 1849 +++++++++++++++++ .../deepseek_v3/modeling_deepseek_orig.py | 1667 +++++++++++++++ .../deepseek_v3/modeling_deepseek_qeff.py | 743 +++++++ .../transformers/models/pytorch_transforms.py | 17 +- examples/run_kimik2.py | 57 + pyproject.toml | 2 + 8 files changed, 4552 insertions(+), 2 deletions(-) create mode 100644 QEfficient/transformers/models/deepseek_v3/__init__.py create mode 100644 QEfficient/transformers/models/deepseek_v3/configuration_deepseek.py create mode 100755 QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py create mode 100644 QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py create mode 100644 QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py create mode 100644 examples/run_kimik2.py diff --git a/QEfficient/transformers/models/deepseek_v3/__init__.py b/QEfficient/transformers/models/deepseek_v3/__init__.py new file mode 100644 index 0000000000..da26921c50 --- /dev/null +++ b/QEfficient/transformers/models/deepseek_v3/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DeepseekV3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a
    DeepSeek model according to the specified arguments, defining the model architecture. The default values below are
    the ones declared in the `__init__` signature of this class.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`DeepseekV3Model`].
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, it falls back to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Expert-parallel size. When greater than 1, `DeepseekV3MoE` asserts that it equals
            `torch.distributed.get_world_size()` and only instantiates the slice of routed experts owned by the
            current rank.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor for routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Width of the compressed key/value latent. The attention module projects hidden states to
            `kv_lora_rank + qk_rope_head_dim` features and applies an RMSNorm over `kv_lora_rank` features.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Width of the low-rank query bottleneck. If `None`, the attention module uses a single full-rank query
            projection instead of the down-projection / norm / up-projection chain.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Per-head query/key dimension that carries rotary position embeddings; the rotary embedding module is
            built with this dimension.
        v_head_dim (`int`, *optional*, defaults to 128):
            Per-head value dimension; the attention output projection consumes `num_attention_heads * v_head_dim`
            features.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Per-head query/key dimension without rotary position embeddings. The full query head dimension is
            `qk_nope_head_dim + qk_rope_head_dim`.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Topk method used in routed gate. `MoEGate` in this package only implements `"noaux_tc"` and raises
            `NotImplementedError` for any other value.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (for each token, ensuring the selected experts are only within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights. `MoEGate` in this package only implements `"sigmoid"` and raises
            `NotImplementedError` for any other value.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to `True`):
            Whether to compute the auxiliary loss for each individual sample.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. The attention module reads the
            `"type"` and `"factor"` keys and dispatches on the types `"linear"`, `"dynamic"` and `"yarn"`; for
            `"yarn"` it additionally forwards `original_max_position_embeddings`, `beta_fast`, `beta_slow`, `mscale`
            and `mscale_all_dim` when present. The expected minimal format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import DeepseekV3Model, DeepseekV3Config

    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()

    >>> # Initializing a model from the configuration
    >>> model = DeepseekV3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="noaux_tc",
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func="sigmoid",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model-specific attributes are stored before delegating to PretrainedConfig.__init__,
        # which receives only the token ids, the tie flag and any extra kwargs.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py new file mode 100755 index 0000000000..1192a0063d --- /dev/null +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -0,0 +1,1849 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV3Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV3Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) + + +class DeepseekV3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def 
yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + 
"sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Regroup the last dimension from interleaved pairs (x0, y0, x1, y1, ...)
    # into two contiguous halves (x0, x1, ..., y0, y1, ...), the layout
    # rotate_half() operates on. Both tensors must be 4-D for the unpacking below.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        """
        Args:
            config: object providing ``hidden_size``, ``intermediate_size`` and ``hidden_act``.
            hidden_size: overrides ``config.hidden_size`` when given.
            intermediate_size: overrides ``config.intermediate_size`` when given.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and return a tensor of the input width."""
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class MoEGate(nn.Module):
    """Router that picks ``num_experts_per_tok`` routed experts per token.

    Only ``scoring_func == "sigmoid"`` and ``topk_method == "noaux_tc"`` are
    implemented; other values raise ``NotImplementedError`` in ``forward``.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing ``num_experts_per_tok``, ``n_routed_experts``,
                ``routed_scaling_factor``, ``scoring_func``, ``seq_aux``,
                ``topk_method``, ``n_group``, ``topk_group``, ``norm_topk_prob``
                and ``hidden_size``.
        """
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing weight and, when present, the score-correction bias."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # The bias was previously left as uninitialised `torch.empty` memory, so
        # routing was arbitrary (possibly NaN/inf) unless a checkpoint overwrote it.
        if hasattr(self, "e_score_correction_bias"):
            nn.init.zeros_(self.e_score_correction_bias)

    def forward(self, hidden_states):
        """Route each token to its top-k experts.

        Args:
            hidden_states: tensor of shape ``(bsz, seq_len, hidden)``.

        Returns:
            Tuple ``(topk_idx, topk_weight)``, both of shape
            ``(bsz * seq_len, num_experts_per_tok)``; the weights are float32.

        Raises:
            NotImplementedError: for an unsupported ``scoring_func`` or ``topk_method``.
            AssertionError: if called in training mode with ``topk_method == "noaux_tc"``.
        """
        bsz, seq_len, h = hidden_states.shape
        num_tokens = bsz * seq_len
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            assert not self.training
            # The bias only influences which experts are chosen; the returned
            # weights are gathered from the unbiased scores below.
            scores_for_choice = scores.view(
                num_tokens, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # A group's score is the sum of its two best biased expert scores.
            group_scores = (
                scores_for_choice.view(num_tokens, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(num_tokens, self.n_group, self.n_routed_experts // self.n_group)
                .reshape(num_tokens, -1)
            )  # [n, e]
            # NOTE(review): experts outside the chosen groups are filled with 0.0,
            # not -inf, so a negative biased score inside a chosen group would rank
            # below a masked expert -- confirm this matches the reference model.
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Inference only: ``forward`` raises ``NotImplementedError`` in training mode.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing ``num_experts_per_tok``, ``n_routed_experts``,
                ``moe_intermediate_size``, ``n_shared_experts`` and optionally
                ``ep_size``, plus whatever ``MoEGate`` and ``DeepseekV3MLP`` read.
        """
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if getattr(config, "ep_size", 1) > 1:
            # Expert parallelism: this rank only materialises its own slice of the
            # routed experts; all other slots in the list are None.
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            first_local = self.ep_rank * self.experts_per_rank
            last_local = (self.ep_rank + 1) * self.experts_per_rank
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV3MLP(
                            config, intermediate_size=config.moe_intermediate_size
                        )
                        if first_local <= i < last_local
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV3MLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for _ in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size
            )

    def forward(self, hidden_states):
        """Apply routed (and, if configured, shared) experts to ``hidden_states``.

        Args:
            hidden_states: input tensor; its last dimension is the hidden size.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            NotImplementedError: in training mode, if the gate did not already
                reject the call. Previously this path failed with
                ``UnboundLocalError`` because ``y`` was never assigned.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        if self.training:
            raise NotImplementedError(
                "DeepseekV3MoE only implements the inference path; call .eval() first"
            )
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Run the routed experts over flattened tokens and combine their outputs.

        Args:
            x: tokens of shape ``(num_tokens, hidden)``.
            topk_ids: integer expert indices of shape ``(num_tokens, top_k)``.
            topk_weight: routing weights of shape ``(num_tokens, top_k)``.

        Returns:
            Tensor of shape ``(num_tokens, hidden)`` in the dtype of the expert
            outputs: the weighted sum of each token's selected expert outputs.
        """
        # Count the tokens assigned to every expert.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort the (token, slot) pairs by expert id so that each expert sees one
        # contiguous slice; `idxs // top_k` recovers the source token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Exchange per-expert counts, then the tokens themselves, between ranks.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (
                tokens_per_expert_group.view(self.ep_size, -1)
                .sum(1)
                .cpu()
                .numpy()
                .tolist()
            )
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank
            ).sum(dim=0)
            # Label every received token with its local expert index, then sort
            # so that tokens for the same local expert are contiguous.
            gathered_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gathered_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gathered_idxs = gathered_idxs.argsort()
            sorted_tokens = gathered_tokens[gathered_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local sort and send every result back to its source rank.
            new_x = torch.empty_like(outs)
            new_x[gathered_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Scatter back to the original (token, slot) order and take the weighted sum.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When ``n_rep == 1`` the input tensor itself is returned without copying.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV3RotaryEmbedding( + self.qk_rope_head_dim, + 
max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = 
attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, 
self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that 
is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV3Attention, + "flash_attention_2": DeepseekV3FlashAttention2, +} + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is 
used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
+ + Parameters: + config ([`DeepseekV3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV3DecoderLayer`] + + Args: + config: DeepseekV3Config + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You 
cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + 
if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + 
cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def 
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Bias-free linear head mapping every hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped decoder."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Scores are computed for every position; one position per row is pooled below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        elif input_ids is not None:
            # Position just before the first pad token of each row. Rows that
            # contain no pad token yield argmax == 0 and therefore -1.
            sequence_lengths = (
                torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            )
            # Fold -1 onto the last position explicitly instead of relying on
            # negative (wrap-around) indexing; the selected element is the same.
            # NOTE(review): upstream transformers applies the same modulo and
            # cites ONNX export compatibility as the reason.
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            # Padding cannot be detected from `inputs_embeds`; use the last position.
            sequence_lengths = -1

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+"""PyTorch DeepSeek model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available + +from .configuration_deepseek import DeepseekV3Config + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
class DeepseekV3RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    The tables are stored in the non-persistent buffers ``cos_cached`` and
    ``sin_cached`` and are rebuilt whenever a longer sequence than the cached
    one is requested.

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies are created.
        max_position_embeddings: Length used for the table built in ``__init__``.
        base: Base of the geometric frequency progression.
        device: Device on which ``inv_freq`` is created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Mark the table as stale: the first `forward` call always rebuilds it
        # with the device and dtype of the tensor it receives.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables truncated to ``seq_len`` rows, cast to ``x.dtype``.

        Args:
            x: Tensor laid out as ``[bs, num_attention_heads, seq_len, head_size]``;
                only its device, dtype and (when ``seq_len`` is omitted) its
                second-to-last dimension are used.
            seq_len: Number of positions required. When ``None`` it is taken
                from ``x.shape[-2]``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash (`None > int`, `torch.arange(None)`);
            # fall back to the sequence dimension of `x` instead.
            seq_len = x.shape[-2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Invert the rotation-count formula: return the (fractional) dimension index
    whose frequency completes ``num_rotations`` turns over ``max_position_embeddings``."""
    scaled_log = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return scaled_log / (2 * math.log(base))


def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the integer ``(low, high)`` dimension bounds for the given rotation
    counts, clamped to ``[0, dim - 1]``."""
    lower = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case
    return max(lower, 0), min(upper, dim - 1)


def yarn_get_mscale(scale=1, mscale=1):
    """Return ``1.0`` for ``scale <= 1``, otherwise ``0.1 * mscale * ln(scale) + 1.0``."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are a per-dimension blend of
    the unscaled and the ``scaling_factor``-scaled frequencies, with the cos/sin
    tables multiplied by an mscale ratio."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # These must exist before the parent constructor runs, because it
        # calls `_set_cos_sin_cache`, which reads all of them.
        self.mscale_all_dim = mscale_all_dim
        self.mscale = mscale
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.original_max_position_embeddings = original_max_position_embeddings
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild ``inv_freq`` and the cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        rot_dim = self.dim

        # base ** (2i / dim), shared by the scaled and unscaled frequencies.
        exponents = torch.arange(0, rot_dim, 2, dtype=torch.float32, device=device) / rot_dim
        base_pow = self.base**exponents
        freq_extra = 1.0 / base_pow
        freq_inter = 1.0 / (self.scaling_factor * base_pow)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            rot_dim,
            self.base,
            self.original_max_position_embeddings,
        )
        ramp = yarn_linear_ramp_mask(low, high, rot_dim // 2).to(device=device, dtype=torch.float32)
        inv_freq_mask = 1.0 - ramp
        # Mask 1 keeps the unscaled frequency, mask 0 the scaled one.
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)

        table_scale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", (table.cos() * table_scale).to(dtype), persistent=False)
        self.register_buffer("sin_cached", (table.sin() * table_scale).to(dtype), persistent=False)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    ``hidden_size`` and ``intermediate_size`` default to the values on
    ``config`` when not given explicitly.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else config.intermediate_size
        )

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Tokens are routed by ``MoEGate`` to ``config.num_experts_per_tok`` routed
    experts; when ``config.n_shared_experts`` is set, the output of an extra
    shared MLP is added to the routed result. Only the inference path is
    implemented.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            # Expert parallelism: this rank instantiates only its own contiguous
            # slice of experts; every other slot of the list is None.
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                        if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # One MLP whose width is `n_shared_experts` times the per-expert width.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        """Apply routed (and optional shared) experts to ``hidden_states``.

        Returns a tensor with the same shape as the input.

        Raises:
            NotImplementedError: if the module is in training mode. No training
                path exists; previously this case fell through to an
                ``UnboundLocalError`` on the undefined result.
        """
        if self.training:
            raise NotImplementedError(
                "DeepseekV3MoE only implements the inference path; call `.eval()` before running forward."
            )
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        # Flatten all leading dimensions into a single token axis for dispatch.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Run every selected expert on its tokens and combine the weighted outputs.

        Args:
            x: Flattened token matrix, one row per token.
            topk_ids: Expert indices selected for each token (one row per token).
            topk_weight: Routing weights matching ``topk_ids``.

        Returns:
            One row per token: the sum over its selected experts of
            ``expert(token) * weight``, cast back to the expert output dtype.
        """
        # Count how many (token, slot) assignments each expert received.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort the flattened assignments by expert id; `idxs // top_k` maps each
        # sorted slot back to the row of the token that owns it.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Exchange per-expert counts, then the tokens themselves, across ranks.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            # Label every received token with its local expert index, then sort
            # so that tokens of the same local expert become contiguous.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Feed each expert its contiguous slice of the sorted tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local-expert sort and send results back to the owning ranks.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Undo the expert sort so rows line up with the flattened (token, slot) order.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + 
base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + 
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - 
self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. 
class DeepseekV3DecoderLayer(nn.Module):
    # One decoder layer: pre-norm self-attention and a pre-norm feed-forward block (MoE or dense MLP),
    # each wrapped in a residual connection.

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        """Build the attention module, the feed-forward block and both RMSNorm layers for layer ``layer_idx``."""
        super().__init__()
        self.hidden_size = config.hidden_size

        attention_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        # MoE is used only when routed experts are configured, the layer is past the leading dense
        # layers, and the layer index falls on the MoE frequency.
        uses_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if uses_moe:
            self.mlp = DeepseekV3MoE(config)
        else:
            self.mlp = DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply self-attention and the feed-forward block, each with a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                mask of size `(batch_size, sequence_length)` for flash attention, or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` for the default attention.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (*optional*): cache forwarded unchanged to the attention module.
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights returned by the attention module are appended to the output.
            use_cache (`bool`, *optional*):
                When truthy, the cache returned by the attention module is appended to the output.

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights (if requested)
            and then the cache (if requested).
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self-attention sub-block (pre-norm + residual).
        attn_out, attn_weights, present = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_out

        collected = [hidden_states]
        if output_attentions:
            collected.append(attn_weights)
        if use_cache:
            collected.append(present)
        return tuple(collected)
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    # Shared base class: declares the config class, loading/placement flags and weight initialization
    # used by the DeepseekV3 model classes.
    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize `nn.Linear` / `nn.Embedding` weights from N(0, initializer_range); other modules are untouched.

        Linear biases (when present) and the embedding row at `padding_idx` (when set) are zeroed.
        """
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. 
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        """Create the token embedding, the stack of decoder layers and the final RMSNorm."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Decides in `forward` whether layers receive the raw 2D mask or an expanded 4D causal mask.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Runs the embedding, every decoder layer and the final norm.
        # Returns a `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple of the non-None
        # values among (hidden_states, cache, all_hidden_states, all_self_attns).
        # Raises ValueError when both or neither of `input_ids` / `inputs_embeds` are given.

        # Per-call flags fall back to the config defaults when left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Anything that is not a `Cache` instance (including None) is treated as the legacy format
            # and wrapped in a `DynamicCache`; it is converted back before returning.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)
        # NOTE(review): when `use_cache` is falsy, `past_key_values` is still handed to the layers
        # unconverted and `past_key_values_length` stays 0 — confirm callers never pass a cache with
        # `use_cache=False`.

        if position_ids is None:
            # Default positions continue after the cached tokens; shape (1, seq_length).
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            # (dropped entirely when it contains no 0, i.e. nothing is masked)
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # The token embeddings are the input to the first decoder layer.
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights in the layer's output tuple when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            # `use_legacy_cache` is only bound when `use_cache` is truthy, hence the guard.
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Assemble the keyword arguments for one generation step.

    When a cache is supplied, `input_ids` is trimmed to the tokens that have
    not been processed yet and, if the cache reports a maximum length, the
    attention mask is cropped to that length.  Position ids are derived from
    the attention mask unless the caller passes `position_ids` in `kwargs`.

    Returns:
        dict with either `inputs_embeds` (only when embeddings are given and
        there is no cache) or `input_ids`, followed by `position_ids`,
        `past_key_values`, `use_cache` and `attention_mask`.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cached_tokens = past_key_values.get_seq_length()
            seen_tokens = past_key_values.seen_tokens
            cache_capacity = past_key_values.get_max_length()
        else:
            # Legacy tuple cache: the key tensor's third axis is the cached length.
            seen_tokens = past_key_values[0][0].shape[2]
            cached_tokens = seen_tokens
            cache_capacity = None

        # Keep only the unprocessed tokens.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, so part of the input
            # lives only in the cache (e.g. when inputs_embeds were passed).
            unprocessed = attention_mask.shape[1] - seen_tokens
            input_ids = input_ids[:, -unprocessed:]
        elif seen_tokens < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the ones already seen.
            input_ids = input_ids[:, seen_tokens:]
        # Case 3 (seen_tokens >= input_ids.shape[1]): input_ids is assumed to
        # contain only unprocessed tokens already.

        # Crop the mask when the next step would exceed the cache capacity.
        would_overflow = (
            cache_capacity is not None
            and attention_mask is not None
            and cached_tokens + input_ids.shape[1] > cache_capacity
        )
        if would_overflow:
            attention_mask = attention_mask[:, -cache_capacity:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Build position ids on the fly for batched generation; masked slots get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # `inputs_embeds` is only used for the very first generation step.
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Bias-free linear head mapping each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.score(transformer_outputs[0])

        batch_source = input_ids if input_ids is not None else inputs_embeds
        batch_size = batch_source.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is not None and input_ids is not None:
            # Index of the token just before the first pad token in each row
            # (argmax of the pad indicator, minus one).
            first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
            sequence_lengths = (first_pad - 1).to(logits.device)
        else:
            # No way to locate padding: pool the last position of every row.
            sequence_lengths = -1

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                # Infer the problem type once and store it on the config.
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned untouched when n_rep is 1.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Inverse dim formula: find the (fractional) dimension index for a given number of rotations."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator


def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Find the integer dimension bounds for a rotation range, clamped to [0, dim - 1]."""
    lower = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case
    return max(lower, 0), min(upper, dim - 1)


def yarn_get_mscale(scale=1, mscale=1):
    """Return 1.0 for scale <= 1, otherwise 0.1 * mscale * ln(scale) + 1.0."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class DeepseekV3RotaryEmbedding(nn.Module):
    """Rotary position embedding that keeps cos/sin tables as non-persistent buffers."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim, self.max_position_embeddings, self.base = dim, max_position_embeddings, base

        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # The cached length is cleared, so the first `forward` call rebuilds the
        # tables for the device/dtype of its input.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the `cos_cached` / `sin_cached` buffers for `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq.to(positions.device))

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the first `seq_len` rows of the cos and sin tables, cast to `x.dtype`."""
        # x: [bs, num_attention_heads, seq_len, head_size]
        needs_rebuild = self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached
        if needs_rebuild:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        out_dtype = x.dtype
        cos = self.cos_cached[:seq_len].to(dtype=out_dtype)
        sin = self.sin_cached[:seq_len].to(dtype=out_dtype)
        return cos, sin
class QEffDeepseekV3RotaryEmbedding(nn.Module):
    """
    Adapted from DeepseekV3RotaryEmbedding with static sin/cos caches like QEffLlamaRotaryEmbedding.

    The tables are precomputed at construction time and gathered by `position_ids` in `forward`.
    """

    def __init__(self, config, device=None):
        super().__init__()
        scaling_cfg = config.rope_scaling
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            # Prefer the "rope_type" key and fall back to "type".
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute static sin/cos caches
        self._set_cos_sin_cache(
            seq_len=self.original_max_seq_len,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the `cos_cached` / `sin_cached` buffers for `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)

    def forward(self, x, position_ids):
        """Gather cos/sin rows at `position_ids`, scale them by `attention_scaling`, cast to `x.dtype`."""
        required_len = torch.max(position_ids) + 1
        if required_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=required_len, device=x.device, dtype=x.dtype)

        # Use position_ids to slice the precomputed caches
        scale = self.attention_scaling
        cos = self.cos_cached[position_ids] * scale
        sin = self.sin_cached[position_ids] * scale
        return cos.to(x.dtype), sin.to(x.dtype)
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to `q` and `k`, slicing the tables by `position_ids` when needed.

    Adapted from DeepseekV3's apply_rotary_pos_emb for QEff compatibility.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine part of the rotary embedding.
        sin: sine part of the rotary embedding.
        position_ids: indices used to slice `cos` / `sin` along their second axis; only
            consulted when the tables are longer than `q` along the second-to-last axis.
        unsqueeze_dim: axis inserted into `cos` / `sin` so they broadcast against `q` and `k`.

    Returns:
        `(q_embed, k_embed)`, each cast back to the dtype of its input.
    """
    # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches)
    tables_are_longer = cos.shape[-2] > q.shape[-2]
    if tables_are_longer:
        cos, sin = cos[:, position_ids, :], sin[:, position_ids, :]
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q).to(q.dtype), _rotate(k).to(k.dtype)
class QEffDeepseekV3Attention(nn.Module):
    """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        batch_index: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run low-rank (compressed KV) attention with eager softmax.

        Args:
            hidden_states: input whose leading two axes are unpacked as (batch, sequence).
            position_embeddings: `(cos, sin)` tables handed to `orig_apply_rotary_pos_emb`.
            attention_mask: boolean condition for `torch.where`; positions where it is true
                have their score replaced by -10000.0 before the softmax.
            position_ids: used to index the cos/sin tables and forwarded to the cache update.
            past_key_value: cache object; when given, its `update` result replaces the
                freshly computed key/value states.
            batch_index: forwarded to the cache update.
            output_attentions: when falsy, the returned attention weights are `None`.
            use_cache, cache_position, **kwargs: accepted for signature compatibility; not
                read by this method.

        Returns:
            `(attn_output, attn_weights_or_None, past_key_value)`.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # Queries: down-project, normalise, up-project, then split into the
        # non-rotary ("pass") and rotary ("rot") parts.
        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: one joint projection yields the compressed KV and the rotary key.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key has a single head here; it is expanded to all heads below.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
        cos, sin = position_embeddings
        # `position_ids` is a required argument of `orig_apply_rotary_pos_emb`;
        # the call previously omitted it and raised TypeError.
        # NOTE(review): a `config.rope_interleave` switch to
        # `apply_rotary_pos_emb_interleave` was previously commented out here;
        # confirm whether it should be restored.
        q_rot, k_rot = orig_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)

        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:
            # Masked positions get a large negative score so softmax drives them to ~0.
            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)

        # Softmax is computed in float32 and cast back to the query dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_length, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + # breakpoint() + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + # token_indices, weight_indices = torch.where(mask) + # if token_indices.numel() > 0: + # if torch.sum(mask).item() > 0: + # expert_weights = topk_weights[token_indices, weight_indices] + # expert_input = hidden_states[token_indices] + # expert_output = expert(expert_input) + expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None]) + # weighted_output = expert_output * expert_weights.unsqueeze(-1) + # final_hidden_states.index_add_(0, token_indices, weighted_output) + expert_output = torch.where( + (topk_weights * mask).sum(1).to(torch.bool)[:, None], + expert_output, + torch.tensor(0.0), + ) + final_hidden_states = final_hidden_states + expert_output + return final_hidden_states.type(hidden_states.dtype) + + +class QEffDeepseekV3DecoderLayer(nn.Module): + """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + 
class QEffDeepseekV3Model(nn.Module):
    """Adapted DeepseekV3Model with batch_index and QEff rotary embedding."""

    def __qeff_init__(self):
        """Install a YaRN rotary embedding built from `config.rope_scaling`."""
        scaling_factor = self.config.rope_scaling["factor"]
        # Forward only the YaRN options that the config actually defines.
        kwargs = {
            key: self.config.rope_scaling[key]
            for key in [
                "original_max_position_embeddings",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            ]
            if key in self.config.rope_scaling
        }
        # NOTE(review): `qk_rope_head_dim` and `rope_theta` are read from the
        # module itself, not from `self.config` — confirm the patched model
        # class defines them.
        self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
            self.qk_rope_head_dim,
            max_position_embeddings=32 * 1024,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given.  The causal
        mask is built from `position_ids` and the cached key length; the
        incoming `attention_mask` is not read by this method.

        Returns:
            A `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple of
            the non-`None` values among hidden states, legacy cache, all hidden
            states and all attentions.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # NOTE(review): this indexes layer 0 of the cache, so it presumably
        # requires a pre-populated cache — confirm behaviour for an empty or
        # missing `past_key_values`.
        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values[0][0].shape[2])
        hidden_states = inputs_embeds
        # NOTE(review): `position_ids` is passed as the second positional
        # argument of `rotary_emb` — confirm this matches the `forward`
        # signature of the embedding installed by `__qeff_init__`.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    batch_index,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    batch_index=batch_index,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The cache sits after the attention weights when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        # Only convert when a cache exists; with `use_cache` false `next_cache`
        # is None and the unconditional conversion raised AttributeError.
        if next_cache is not None:
            next_cache = next_cache.to_legacy_cache()
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the causal mask for the configured attention implementation.

        Returns `None` when the backend can infer causality itself (flash
        attention without padding, or SDPA when the mask can be ignored),
        the incoming mask for flash attention with padding, and otherwise a
        mask created from `position_ids` and the target length.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens

        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 31c86a9c72..d77d1b8e0c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -249,6 +249,7 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconDecoderLayer, @@ -942,10 +943,22 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "RMSNorm": { "forward": QEFFGrok1CustomRMSNormAIC.forward, }, + "DeepseekV3ForCausalLM":{ + "forward": QEffDeepseekV3ForCausalLM.forward, + }, + "DeepseekV3Model":{ + "forward": QEffDeepseekV3Model.forward, + "__qeff_init__": QEffDeepseekV3Model.__qeff_init__ + }, + "DeepseekV3DecoderLayer": { + "forward": QEffDeepseekV3DecoderLayer.forward, + }, + "DeepseekV3MoE": { + "forward": QEffDeepseekV3MoE.forward, + "moe": QEffDeepseekV3MoE.moe + }, } - _match_class_replace_method = {} - class T5ModelTransform(ModuleMappingTransform): # supported architectures diff --git 
a/examples/run_kimik2.py b/examples/run_kimik2.py new file mode 100644 index 0000000000..01b33c544f --- /dev/null +++ b/examples/run_kimik2.py @@ -0,0 +1,57 @@ +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float16, num_hidden_layers=2, trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) +PREFILL_SEQ_LEN=128 + + +prompts = "Once upon a time," +inputs = tokenizer(prompts, return_tensors="pt", padding=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + +with torch.no_grad(): + out = model(**inputs) + predictions = torch.argmax(out.logits, dim=-1) + + +qeff_model = QEFFAutoModelForCausalLM(model) + +inputs = tokenizer(prompts, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(model.config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + +qeff_out = qeff_model.model(**inputs) + +assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 +# qeff_model = QEFFAutoModelForCausalLM(model) +# qeff_model.compile( +# prefill_seq_len=1, +# 
num_devices=1, +# use_onnx_subfunctions=True, +# ctx_len=8192, +# mxfp6_matmul=True, +# # mxint8_kv_cache=True, +# mos=1, +# aic_enable_depth_first=True, +# num_cores=16, +# offload_pt_weights=True, +# ) +# tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.7") +# qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 62636f96ae..5e5060f5c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ "ftfy==6.3.1", "imageio==2.37.2", "imageio-ffmpeg==0.6.0", + "tiktoken", + "compressed-tensors", "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", From 426eed13f756433edd1df4e79c7876e2be545145 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 28 Jan 2026 19:10:15 +0000 Subject: [PATCH 02/51] able to run kimi model, need to check accuracy Signed-off-by: Onkar Chougule --- .../deepseek_v3/modeling_deepseek_qeff.py | 90 +++++++++++++------ .../models/grok_1/modeling_grok1.py | 2 +- .../transformers/models/modeling_auto.py | 17 ++-- .../transformers/models/pytorch_transforms.py | 9 +- examples/run_kimik2.py | 12 +-- 5 files changed, 87 insertions(+), 43 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 46ba8f55ed..602d34b82c 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -128,6 +128,16 @@ def forward(self, x, seq_len=None): self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) + + # def forward(self, x, position_ids): + # 
seq_len = torch.max(position_ids) + 1 + # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + # # Use position_ids to slice the precomputed caches + # cos = self.cos_cached[position_ids] + # sin = self.sin_cached[position_ids] + # return cos.to(x.dtype), sin.to(x.dtype) @@ -315,6 +325,7 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ + # import ipdb; ipdb.set_trace() cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) @@ -332,6 +343,11 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" + + def __qeff_init__(self,): + self.merged_q_weight = torch.matmul(self.q_a_proj.weight.T * self.q_a_layernorm.weight.unsqueeze(0), self.q_b_proj.weight.T) + self.merged_k_weight = torch.matmul(self.k_a_proj.weight.T * self.k_a_layernorm.weight.unsqueeze(0), self.k_b_proj.weight.T) + self.merged_v_weight = torch.matmul(self.v_a_proj.weight.T * self.v_a_layernorm.weight.unsqueeze(0), self.v_b_proj.weight.T) def forward( self, @@ -346,39 +362,53 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - batch_size, seq_length = hidden_states.shape[:-1] - query_shape = (batch_size, seq_length, -1, self.qk_head_dim) - key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) - - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) - q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + bsz, q_len, _ 
= hidden_states.size() + # import ipdb; ipdb.set_trace() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) - k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) - cos, sin = position_embeddings - # if self.config.rope_interleave: - # q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) - # else: - q_rot, k_rot = orig_apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - - k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) - query_states = torch.cat((q_pass, q_rot), dim=-1) - key_states = torch.cat((k_pass, k_rot), dim=-1) + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + # cos, sin = position_embeddings + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=32*1024) + + # q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, 
:, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + # import ipdb; ipdb.set_trace() if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) @@ -387,7 +417,7 @@ def forward( attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_length, -1) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -497,12 +527,13 @@ def __qeff_init__(self): if key in self.config.rope_scaling } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( - self.qk_rope_head_dim, + self.config.qk_rope_head_dim, max_position_embeddings=32*1024, scaling_factor=scaling_factor, - base=self.rope_theta, + base=self.config.rope_theta, **kwargs, ) + # import ipdb; ipdb.set_trace() def forward( self, @@ -552,7 +583,8 @@ def forward( causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values[0][0].shape[2]) 
hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = None all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 51bdaa4ea4..0f88fe1b92 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -39,7 +39,7 @@ def forward(self, hidden_states): torch.Tensor: Normalized tensor. """ return CustomRMSNormFunc.apply( - hidden_states, self.scale, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps + hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9257ff114f..61175dc379 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2989,9 +2989,11 @@ def export( bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len - ) + # kv_cache_shape = get_padding_shape_from_config( + # self.model.config, fbs if self.continuous_batching else bs, seq_len + # ) + kv_cache_shape = (1, 64, seq_len, 192) + kv_cache_shape_v = (1, 64, seq_len, 128) enable_chunking = kwargs.get("enable_chunking", False) if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: @@ -3085,10 +3087,11 @@ def export( ) for i in range(self.num_layers): - for kv in ["key", "value"]: - 
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_{kv}.{i}_RetainedState") + # for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape_v, dtype=torch.float32)) + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index d77d1b8e0c..5e87779a7a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -249,7 +249,7 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) -from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3Attention, QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconDecoderLayer, @@ -888,6 +888,7 @@ class VlmNoKVOffloadTransform(ModuleMappingTransform): class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_class_replace_method = {} _match_string_replace_method = { "InternVLChatModel": { "forward": QEffInternVLModel.forward, @@ -957,6 +958,12 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "forward": QEffDeepseekV3MoE.forward, "moe": QEffDeepseekV3MoE.moe }, + "DeepseekV3Attention":{ + "forward": 
QEffDeepseekV3Attention.forward + }, + "DeepseekV3RMSNorm":{ + "forward": QEFFGrok1CustomRMSNormAIC.forward, + } } diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 01b33c544f..78479ce3e8 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -22,21 +22,23 @@ qeff_model = QEFFAutoModelForCausalLM(model) - +import ipdb; ipdb.set_trace() inputs = tokenizer(prompts, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs.pop("token_type_ids", None) inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} past_key_values = [] for i in range(model.config.num_hidden_layers): - cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN - pad_shape = (1, 8, cache_len, 64) - past_key = torch.zeros((pad_shape), dtype=torch.float32) - past_value = torch.zeros((pad_shape), dtype=torch.float32) + cache_len = 128 + pad_shape_k = (1, 64, cache_len, 192) + pad_shape_v = (1, 64, cache_len, 128) + past_key = torch.zeros((pad_shape_k), dtype=torch.float16) + past_value = torch.zeros((pad_shape_v), dtype=torch.float16) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = past_key_values +qeff_model.compile(prefill_seq_len=1, ctx_len=1024, mxfp6_matmul=True, num_devices=1) qeff_out = qeff_model.model(**inputs) assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 From f3de7230c09a711decd24ffb9e93c18b07075c27 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 2 Feb 2026 09:49:45 +0000 Subject: [PATCH 03/51] bugfix Signed-off-by: Onkar Chougule --- .../deepseek_v3/modeling_deepseek_qeff.py | 40 +++++++++++++++++++ .../transformers/models/modeling_auto.py | 6 ++- examples/run_kimik2.py | 2 +- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 
602d34b82c..b48ab0b1aa 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -427,6 +427,46 @@ def forward( class QEffDeepseekV3MoE(nn.Module): + def __qeff_init__( + self, + ): + self.all_gate_proj = torch.nn.Parameter( + torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.all_up_proj = torch.nn.Parameter( + torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.all_down_proj = torch.nn.Parameter( + torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.act_fn = self.experts[0].act_fn + + def moe( + self, + hidden_states: torch.Tensor, + topk_indices: torch.Tensor, + topk_weights: torch.Tensor, + ): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + gate_proj = self.all_gate_proj[topk_indices.flatten()] + up_proj = self.all_up_proj[topk_indices.flatten()] + down_proj = self.all_down_proj[topk_indices.flatten()] + expert_in = ( + hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.config.hidden_size) + ) + gate_out = torch.bmm(expert_in, gate_proj) + up_out = torch.bmm(expert_in, up_proj) + hidden = self.act_fn(gate_out) * up_out + expert_output = torch.bmm(hidden, down_proj) + experts_out = expert_output.view(bs * seq_len, self.gate.top_k, self.config.hidden_size) + experts_out = experts_out * topk_weights.unsqueeze(-1) + # final_hidden_states = experts_out.sum(dim=1) + final_hidden_states = torch.einsum("abc->ac", experts_out) + + return final_hidden_states.type(hidden_states.dtype) + def forward(self, hidden_states): residuals = hidden_states orig_shape = hidden_states.shape diff --git a/QEfficient/transformers/models/modeling_auto.py 
b/QEfficient/transformers/models/modeling_auto.py index 61175dc379..3ae96aebec 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3090,8 +3090,10 @@ def export( # for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape_v, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_{kv}.{i}_RetainedState") + dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes[i] + dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_key.{i}_RetainedState") + output_names.append(f"past_value.{i}_RetainedState") if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 78479ce3e8..98bfa1252b 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -38,7 +38,7 @@ past_key_values.append(pkv) inputs["past_key_values"] = past_key_values -qeff_model.compile(prefill_seq_len=1, ctx_len=1024, mxfp6_matmul=True, num_devices=1) +# qeff_model.compile(prefill_seq_len=1, ctx_len=1024, mxfp6_matmul=True, num_devices=1) qeff_out = qeff_model.model(**inputs) assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 From a3abf05c0dee8c8e0ea6e5729c517148ffcb486e Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 6 Feb 2026 09:51:59 +0000 Subject: [PATCH 04/51] experimentation branch commit Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 8 + QEfficient/transformers/cache_utils.py | 25 +++ .../deepseek_v3/modeling_deepseek_qeff.py | 155 +++++++++++++----- .../transformers/models/modeling_auto.py | 29 ++++ .../transformers/models/pytorch_transforms.py | 4 +- examples/run_kimik2.py | 21 ++- 6 files changed, 193 insertions(+), 49 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py 
b/QEfficient/base/modeling_qeff.py index 6f22e867ef..b279bce1f5 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -328,11 +328,15 @@ def get_onnx_path( offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, retain_full_kv: Optional[bool] = False, + enable_mla: Optional[bool] = False, + enable_mla_absorption: Optional[bool] = False, ): kwargs = { "offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, + "enable_mla": enable_mla, + "enable_mla_absorption": enable_mla_absorption, } if prefill_only: @@ -365,6 +369,8 @@ def _compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, + enable_mla: Optional[bool] = False, + enable_mla_absorption: Optional[bool] = False, **compiler_options, ) -> str: """ @@ -402,6 +408,8 @@ def _compile( offload_pt_weights, use_onnx_subfunctions, retain_full_kv, + enable_mla, + enable_mla_absorption, ) ) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 6ebccdfbf8..5f050f7cf5 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -343,6 +343,31 @@ def update3D( return k_out, v_out +class QEffDynamicCompressedKVLayer: + def __init__(self, ckv): + pass + + + +class QEffDynamicCompressedKVCache: + def __init__(self, ddp_cache_data = None, *args, **kwargs): + super().__init__(ddp_cache_data, *args, **kwargs) + self.layers=[] + + def update(self, ckv, layer_idx): + self.layers.append(QEffDynamicCompressedKVCache()) + + @classmethod + def from_legacy_cache(cls, past_key_values): + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + ckv = past_key_values[layer_idx] + cache.update(ckv, layer_idx) + return cache + + + class QEffDynamicCache(Cache): 
""" A cache that grows dynamically as more tokens are generated. This is the default for generative models. diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index b48ab0b1aa..02e2dfbe64 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -343,11 +343,74 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" - + def __qeff_init__(self,): - self.merged_q_weight = torch.matmul(self.q_a_proj.weight.T * self.q_a_layernorm.weight.unsqueeze(0), self.q_b_proj.weight.T) - self.merged_k_weight = torch.matmul(self.k_a_proj.weight.T * self.k_a_layernorm.weight.unsqueeze(0), self.k_b_proj.weight.T) - self.merged_v_weight = torch.matmul(self.v_a_proj.weight.T * self.v_a_layernorm.weight.unsqueeze(0), self.v_b_proj.weight.T) + self.q_up, self.q_rope = self.q_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.qk_rope_head_dim).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + self.q_up = self.q_up.reshape(-1, self.num_heads*self.qk_nope_head_dim) + self.q_rope = self.q_rope.reshape(-1, self.num_heads* self.qk_rope_head_dim) + self.k_up, self.v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.v_head_dim).split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + self.k_up = self.k_up.reshape(-1, self.num_heads*self.qk_nope_head_dim) + self.v_up = self.v_up.reshape(-1, self.num_heads*self.v_head_dim) + self.fusedqk = torch.bmm(self.q_up.view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1), self.k_up.view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2)) + + def fused_forward( + self, + hidden_states: 
torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + import ipdb; ipdb.set_trace() + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope).view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + q_nope = torch.matmul(q_a_proj_out, self.q_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(compressed_kv) + k_nope = torch.matmul(kva, self.k_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + value_states = torch.matmul(kva, self.v_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=32*1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if enable_mla_absorption: + atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)) + else: + atn = torch.matmul(q_nope, k_nope.transpose(2, 3)) + + atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) + attn_weights = (atn+atr) * self.softmax_scale + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, 
torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + ctx_len = past_key_value[self.layer_idx][0].shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + value_states = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), value_states) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + # if not output_attentions: + # attn_weights = None + + return attn_output, attn_weights, past_key_value, value_states def forward( self, @@ -388,21 +451,13 @@ def forward( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - - # cos, sin = position_embeddings - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=32*1024) - # q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.repeat(1,self.num_heads,1,1) + key_states = torch.cat((k_nope, k_pe_new), -1) # import ipdb; ipdb.set_trace() if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": 
position_ids} @@ -420,10 +475,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None + # if not output_attentions: + # attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, value_states class QEffDeepseekV3MoE(nn.Module): @@ -517,23 +572,41 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + enable_mla: Optional[bool] = False, + enable_mla_absorption: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + if enable_mla: + hidden_states, self_attn_weights, present_key_value, vs = self.self_attn.fused_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + enable_mla_absorption=enable_mla_absorption, + **kwargs, + ) - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - position_embeddings=position_embeddings, - past_key_value=past_key_value, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) + else: + hidden_states, self_attn_weights, present_key_value, vs = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_key_value, 
+ batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + import ipdb; ipdb.set_trace() hidden_states = residual + hidden_states residual = hidden_states @@ -580,6 +653,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + compressed_kvs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -590,7 +664,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - # breakpoint() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -601,17 +674,17 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and not isinstance(past_key_values, Cache): - # past_key_values = ( - # DynamicCache() if past_key_values is None else DynamicCache.from_legacy_cache(past_key_values) - # ) + if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - + enable_mla = getattr(self, "enable_mla", False) + if enable_mla: + compressed_kvs = QEffDynamicCache.from_legacy_cache(compressed_kvs) + + + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -623,7 +696,6 @@ def forward( causal_mask = _create_causal_mask(position_ids=position_ids, 
target_length=past_key_values[0][0].shape[2]) hidden_states = inputs_embeds - # position_embeddings = self.rotary_emb(hidden_states, position_ids) position_embeddings = None all_hidden_states = () if output_hidden_states else None @@ -652,12 +724,15 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + compressed_kvs = compressed_kvs, past_key_value=past_key_values, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + enable_mla = getattr(self, "enable_mla", False), + enable_mla_absorption = getattr(self, "enable_mla_absorption", False), **kwargs, ) @@ -755,6 +830,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + compressed_kvs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -777,6 +853,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + compressed_kvs = compressed_kvs, past_key_values=past_key_values, batch_index=batch_index, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 3ae96aebec..c2a63fefc0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2995,6 +2995,16 @@ def export( kv_cache_shape = (1, 64, seq_len, 192) kv_cache_shape_v = (1, 64, seq_len, 128) enable_chunking = kwargs.get("enable_chunking", False) + + # TODO: HACK handle better + if enable_mla:=kwargs.get('enable_mla', False): + self.hash_params['enable_mla'] = enable_mla + setattr(self.model.model, "enable_mla", enable_mla) + if enable_mla_absorption:=kwargs.get('enable_mla_absorption', 
False): + self.hash_params['enable_mla_absorption'] = enable_mla_absorption + setattr(self.model.model, "enable_mla_absorption", enable_mla_absorption) + + # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: if not enable_chunking and self.continuous_batching: @@ -3113,6 +3123,18 @@ def export( vocab_size=self.model.config.vocab_size, qaic_config=self.model.qaic_config, ) + if enable_mla: + [example_inputs.pop(k) for k in example_inputs.keys() if "past" in k] + [dynamic_axes.pop(k) for k in dynamic_axes.keys() if "past" in k] + output_names = [v for v in output_names if "past" not in v] + example_inputs['compressed_kv'] = [] + for i in range(self.num_layers): + example_inputs['compressed_kv'][i] = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim), dtype=torch.float32) + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "seq_len"} + output_names.append(f"compressed_kv.{i}_RetainedState") + + + return self._export( example_inputs, output_names=output_names, @@ -3266,9 +3288,12 @@ def compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, + enable_mla: Optional[bool] = False, + enable_mla_absorption: Optional[bool] = False, **compiler_options, ) -> str: """ + Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler. This method generates a ``qpc`` package. If the model has not been exported yet, @@ -3346,6 +3371,8 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" + if enable_mla_absorption and not enable_mla: + logger.warning("enable_mla_fusion will be ignored as enable_mla is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( "`kv_cache_batch_size` or `full_batch_size` is being passed" @@ -3495,6 +3522,8 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, + enable_mla = enable_mla, + enable_mla_absorption = enable_mla_absorption, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5e87779a7a..bc2d6a0707 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -959,7 +959,9 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "moe": QEffDeepseekV3MoE.moe }, "DeepseekV3Attention":{ - "forward": QEffDeepseekV3Attention.forward + "forward": QEffDeepseekV3Attention.forward, + "fused_forward": QEffDeepseekV3Attention.fused_forward, + "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, "DeepseekV3RMSNorm":{ "forward": QEFFGrok1CustomRMSNormAIC.forward, diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 98bfa1252b..64812371eb 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -1,10 +1,13 @@ import numpy as np import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from QEfficient import QEFFAutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float16, num_hidden_layers=2, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained("moonshotai/Kimi-K2-Thinking", num_hidden_layers=1, trust_remote_code=True, torch_dtype=torch.float32) 
+import ipdb; ipdb.set_trace() +#model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float16, num_hidden_layers=2, trust_remote_code=True) +# model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) PREFILL_SEQ_LEN=128 @@ -16,9 +19,9 @@ padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -with torch.no_grad(): - out = model(**inputs) - predictions = torch.argmax(out.logits, dim=-1) +# with torch.no_grad(): +# out = model(**inputs) +# predictions = torch.argmax(out.logits, dim=-1) qeff_model = QEFFAutoModelForCausalLM(model) @@ -32,8 +35,8 @@ cache_len = 128 pad_shape_k = (1, 64, cache_len, 192) pad_shape_v = (1, 64, cache_len, 128) - past_key = torch.zeros((pad_shape_k), dtype=torch.float16) - past_value = torch.zeros((pad_shape_v), dtype=torch.float16) + past_key = torch.zeros((pad_shape_k), dtype=torch.float32) + past_value = torch.zeros((pad_shape_v), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = past_key_values @@ -41,7 +44,7 @@ # qeff_model.compile(prefill_seq_len=1, ctx_len=1024, mxfp6_matmul=True, num_devices=1) qeff_out = qeff_model.model(**inputs) -assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 +# assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 # qeff_model = QEFFAutoModelForCausalLM(model) # qeff_model.compile( # prefill_seq_len=1, @@ -56,4 +59,4 @@ # offload_pt_weights=True, # ) # tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.7") -# qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) \ No newline at end of file +# qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From a27e2720335f9fa03bde273736c596a579ec836f Mon Sep 17 00:00:00 2001 From: Onkar 
Chougule Date: Mon, 9 Feb 2026 18:53:31 +0000 Subject: [PATCH 05/51] added MLA with/WO fusion, the caching for different config needs to be sorted Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 12 +- QEfficient/transformers/cache_utils.py | 45 ++- .../deepseek_v3/modeling_deepseek_qeff.py | 263 +++++++++--------- .../transformers/models/modeling_auto.py | 41 +-- .../transformers/models/pytorch_transforms.py | 3 +- examples/export_kimik2.py | 18 ++ 6 files changed, 214 insertions(+), 168 deletions(-) create mode 100644 examples/export_kimik2.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b279bce1f5..1300f70558 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -269,9 +269,13 @@ def _export( raise ValueError( f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(example_inputs['past_key_values'][0])}" ) + elif param == "compressed_kvs": + for i in range(len(example_inputs["compressed_kvs"])): + input_names.extend([f"compressed_kvs.{i}",]) else: input_names.append(param) + import ipdb; ipdb.set_trace() try: torch.onnx.export( self.model, @@ -329,14 +333,14 @@ def get_onnx_path( use_onnx_subfunctions: Optional[bool] = False, retain_full_kv: Optional[bool] = False, enable_mla: Optional[bool] = False, - enable_mla_absorption: Optional[bool] = False, + mla_absorption_config: Optional[bool] = False, ): kwargs = { "offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, "enable_mla": enable_mla, - "enable_mla_absorption": enable_mla_absorption, + "mla_absorption_config": mla_absorption_config, } if prefill_only: @@ -370,7 +374,7 @@ def _compile( enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, enable_mla: Optional[bool] = False, - enable_mla_absorption: Optional[bool] = False, + mla_absorption_config: 
Optional[Dict[str, bool]] = False, **compiler_options, ) -> str: """ @@ -409,7 +413,7 @@ def _compile( use_onnx_subfunctions, retain_full_kv, enable_mla, - enable_mla_absorption, + mla_absorption_config, ) ) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 5f050f7cf5..9fc3936208 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -345,17 +345,40 @@ def update3D( class QEffDynamicCompressedKVLayer: def __init__(self, ckv): - pass - - + self.ckv = ckv + + def update(self, compressed_kv, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later + + self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv) + + ckv_out = self.ckv + ctx_len = ckv_out.shape[1] + ctx_indices = torch.arange(ctx_len)[None, ...] + gather_limit = position_ids.max(1, keepdim=True).values + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices) + ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) + return ckv_out + + class QEffDynamicCompressedKVCache: - def __init__(self, ddp_cache_data = None, *args, **kwargs): - super().__init__(ddp_cache_data, *args, **kwargs) + def __init__(self,): self.layers=[] - def update(self, ckv, layer_idx): - self.layers.append(QEffDynamicCompressedKVCache()) + def add_new(self, ckv, layer_idx): + self.layers.append(QEffDynamicCompressedKVLayer(ckv)) + + def update(self, ckv, layer_idx, cache_kwargs): + return self.layers[layer_idx].update(ckv, cache_kwargs) @classmethod def from_legacy_cache(cls, past_key_values): @@ -363,9 +386,15 @@ 
def from_legacy_cache(cls, past_key_values): if past_key_values is not None: for layer_idx in range(len(past_key_values)): ckv = past_key_values[layer_idx] - cache.update(ckv, layer_idx) + cache.add_new(ckv, layer_idx) return cache + def to_legacy_cache(self, ): + legacy_cache = () + for layer in self.layers: + legacy_cache += (layer.ckv,) + return legacy_cache + class QEffDynamicCache(Cache): diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 02e2dfbe64..1c67ba11aa 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -25,7 +25,7 @@ # logger, # ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVCache def rotate_half(x): @@ -345,13 +345,30 @@ class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" def __qeff_init__(self,): - self.q_up, self.q_rope = self.q_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.qk_rope_head_dim).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - self.q_up = self.q_up.reshape(-1, self.num_heads*self.qk_nope_head_dim) - self.q_rope = self.q_rope.reshape(-1, self.num_heads* self.qk_rope_head_dim) - self.k_up, self.v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.v_head_dim).split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - self.k_up = self.k_up.reshape(-1, self.num_heads*self.qk_nope_head_dim) - self.v_up = 
self.v_up.reshape(-1, self.num_heads*self.v_head_dim) - self.fusedqk = torch.bmm(self.q_up.view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1), self.k_up.view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2)) + q_up, q_rope = self.q_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.qk_rope_head_dim).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_up = q_up.reshape(-1, self.num_heads*self.qk_nope_head_dim).unsqueeze(0) + # self.register_buffer("q_up", q_up.detach().clone(), persistent=False) + self.q_up = torch.nn.Parameter(q_up.detach().clone()) + q_rope = q_rope.reshape(-1, self.num_heads* self.qk_rope_head_dim).unsqueeze(0) + # self.register_buffer("q_rope", q_rope.detach().clone(), persistent=False) + self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.v_head_dim).split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_up = k_up.reshape(-1, self.num_heads*self.qk_nope_head_dim).unsqueeze(0) + v_up = v_up.reshape(-1, self.num_heads*self.v_head_dim).unsqueeze(0) + # self.register_buffer("k_up", k_up.detach().clone(), persistent=False) + # self.register_buffer("v_up", v_up.detach().clone(), persistent=False) + self.k_up = torch.nn.Parameter(k_up.detach().clone()) + self.v_up = torch.nn.Parameter(v_up.detach().clone()) + per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + per_head_k_up = self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) + # import ipdb; ipdb.set_trace() + # self.register_buffer("per_head_q_up", per_head_q_up.detach().clone(), persistent=False) + # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False) + self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) + self.per_haed_k_up = 
torch.nn.Parameter(per_head_k_up.detach().clone()) + fusedqk = torch.bmm(per_head_q_up, per_head_k_up) + # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) + self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) def fused_forward( self, @@ -360,34 +377,65 @@ def fused_forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - import ipdb; ipdb.set_trace() + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + # import ipdb; ipdb.set_trace() + if compressed_kvs is not None: + cache_kwargs = { + "position_ids": position_ids, + "batch_index": batch_index, + } + compressed_kv = compressed_kvs.update(compressed_kv, self.layer_idx, cache_kwargs) + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.matmul(q_a_proj_out, self.q_rope).view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.matmul(q_a_proj_out, self.q_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + # import ipdb; ipdb.set_trace() + q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + + compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, 
self.qk_rope_head_dim).transpose(1, 2) - kva = self.kv_a_layernorm(compressed_kv) - k_nope = torch.matmul(kva, self.k_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - value_states = torch.matmul(kva, self.v_up).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + kva = self.kv_a_layernorm(compressed_kv) + k_nope = torch.bmm(kva, self.k_up) + k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + value_states = torch.bmm(kva, self.v_up) + value_states = value_states.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, seq_len=32*1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - if enable_mla_absorption: - atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)) + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) + else: + enable_absorption = False + + if enable_absorption: + if absorb_online: + print("online absorption") + + atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), kva.transpose(1, 2).unsqueeze(1)) + else: + + print("using fused qk") + atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)) else: + print("no absorption") atn = torch.matmul(q_nope, k_nope.transpose(2, 3)) atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) @@ -396,12 +444,12 @@ def fused_forward( if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - ctx_len = past_key_value[self.layer_idx][0].shape[2] - ctx_indices = 
torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - value_states = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), value_states) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + # ctx_len = past_key_value[self.layer_idx][0].shape[2] + # ctx_indices = torch.arange(ctx_len)[None, None, ...] + # gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + # invalid_mask = ctx_indices > gather_limit + # value_states = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), value_states) + # attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) @@ -410,7 +458,7 @@ def fused_forward( # if not output_attentions: # attn_weights = None - return attn_output, attn_weights, past_key_value, value_states + return attn_output, attn_weights, compressed_kvs, value_states def forward( self, @@ -426,7 +474,6 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - # import ipdb; ipdb.set_trace() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: @@ -475,8 +522,6 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - # if not output_attentions: - # attn_weights = None return attn_output, attn_weights, past_key_value, value_states @@ -486,13 +531,21 @@ def __qeff_init__( self, ): self.all_gate_proj = torch.nn.Parameter( - torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + torch.cat( + 
[exp.gate_proj.compressor.decompress_module(exp.gate_proj).T.unsqueeze(0) for exp in self.experts], + dim=0, + ) ) self.all_up_proj = torch.nn.Parameter( - torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + torch.cat( + [exp.up_proj.compressor.decompress_module(exp.up_proj).T.unsqueeze(0) for exp in self.experts], dim=0 + ) ) self.all_down_proj = torch.nn.Parameter( - torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + torch.cat( + [exp.down_proj.compressor.decompress_module(exp.down_proj).T.unsqueeze(0) for exp in self.experts], + dim=0, + ) ) self.act_fn = self.experts[0].act_fn @@ -502,9 +555,10 @@ def moe( topk_indices: torch.Tensor, topk_weights: torch.Tensor, ): - bs, seq_len, _ = hidden_states.shape + seq_len, _ = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + gate_proj = self.all_gate_proj[topk_indices.flatten()] up_proj = self.all_up_proj[topk_indices.flatten()] down_proj = self.all_down_proj[topk_indices.flatten()] @@ -515,13 +569,13 @@ def moe( up_out = torch.bmm(expert_in, up_proj) hidden = self.act_fn(gate_out) * up_out expert_output = torch.bmm(hidden, down_proj) - experts_out = expert_output.view(bs * seq_len, self.gate.top_k, self.config.hidden_size) + experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) experts_out = experts_out * topk_weights.unsqueeze(-1) # final_hidden_states = experts_out.sum(dim=1) final_hidden_states = torch.einsum("abc->ac", experts_out) return final_hidden_states.type(hidden_states.dtype) - + def forward(self, hidden_states): residuals = hidden_states orig_shape = hidden_states.shape @@ -530,32 +584,24 @@ def forward(self, hidden_states): hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + 
self.shared_experts(residuals) - return hidden_states - - def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) - expert_mask = expert_mask.permute(2, 0, 1) - # breakpoint() - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - # token_indices, weight_indices = torch.where(mask) - # if token_indices.numel() > 0: - # if torch.sum(mask).item() > 0: - # expert_weights = topk_weights[token_indices, weight_indices] - # expert_input = hidden_states[token_indices] - # expert_output = expert(expert_input) - expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None]) - # weighted_output = expert_output * expert_weights.unsqueeze(-1) - # final_hidden_states.index_add_(0, token_indices, weighted_output) - expert_output = torch.where( - (topk_weights * mask).sum(1).to(torch.bool)[:, None], - expert_output, - torch.tensor(0.0), - ) - final_hidden_states = final_hidden_states + expert_output - return final_hidden_states.type(hidden_states.dtype) + return hidden_states + self.shared_experts(residuals) + + # def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + # final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + # expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + # expert_mask = expert_mask.permute(2, 0, 1) + + # for expert_idx in range(len(self.experts)): + # expert = self.experts[expert_idx] + # mask = expert_mask[expert_idx] + # expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None]) + # expert_output = torch.where( + # (topk_weights * mask).sum(1).to(torch.bool)[:, None], + # expert_output, + # torch.tensor(0.0), + # ) + # 
final_hidden_states = final_hidden_states + expert_output + # return final_hidden_states.type(hidden_states.dtype) class QEffDeepseekV3DecoderLayer(nn.Module): @@ -567,29 +613,31 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, enable_mla: Optional[bool] = False, - enable_mla_absorption: Optional[bool] = False, + mla_absorption: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) if enable_mla: - hidden_states, self_attn_weights, present_key_value, vs = self.self_attn.fused_forward( + hidden_states, self_attn_weights, present_compressed_kvs, vs = self.self_attn.fused_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, position_embeddings=position_embeddings, past_key_value=past_key_value, + compressed_kvs=compressed_kvs, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - enable_mla_absorption=enable_mla_absorption, + mla_absorption=mla_absorption, **kwargs, ) @@ -606,7 +654,7 @@ def forward( cache_position=cache_position, **kwargs, ) - import ipdb; ipdb.set_trace() + # import ipdb; ipdb.set_trace() hidden_states = residual + hidden_states residual = hidden_states @@ -618,7 +666,10 @@ def forward( if output_attentions: outputs += (self_attn_weights,) if use_cache: - outputs += (present_key_value,) + if enable_mla: + outputs += (present_compressed_kvs,) + else: + outputs += (present_key_value,) return outputs @@ -646,7 
+697,6 @@ def __qeff_init__(self): base=self.config.rope_theta, **kwargs, ) - # import ipdb; ipdb.set_trace() def forward( self, @@ -680,11 +730,13 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) enable_mla = getattr(self, "enable_mla", False) + if enable_mla: - compressed_kvs = QEffDynamicCache.from_legacy_cache(compressed_kvs) - - - + compressed_kvs = QEffDynamicCompressedKVCache.from_legacy_cache(compressed_kvs) + target_len = compressed_kvs.layers[0].ckv.shape[-2] + else: + target_len = past_key_values[0][0].shape[2] + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -694,7 +746,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values[0][0].shape[2]) + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_len) hidden_states = inputs_embeds position_embeddings = None @@ -732,7 +784,7 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, enable_mla = getattr(self, "enable_mla", False), - enable_mla_absorption = getattr(self, "enable_mla_absorption", False), + mla_absorption = getattr(self, "mla_absorption_config", None), **kwargs, ) @@ -758,69 +810,6 @@ def forward( attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its 
`attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - # causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - # attention_mask, - # sequence_length=sequence_length, - # target_length=target_length, - # dtype=dtype, - # device=device, - # cache_position=cache_position, - # batch_size=input_tensor.shape[0], - # ) - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - class QEffDeepseekV3ForCausalLM(nn.Module): """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations.""" @@ -843,6 +832,7 @@ def forward( num_logits_to_keep: int = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: + before_keys = self.state_dict().keys() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -882,7 +872,8 @@ def forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + after_keys = self.state_dict().keys() + import ipdb; ipdb.set_trace() return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c2a63fefc0..be9a307323 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -9,7 +9,7 @@ import warnings from pathlib import Path from time import perf_counter -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -3000,9 +3000,9 @@ def export( if enable_mla:=kwargs.get('enable_mla', False): self.hash_params['enable_mla'] = enable_mla setattr(self.model.model, "enable_mla", enable_mla) - if enable_mla_absorption:=kwargs.get('enable_mla_absorption', False): - self.hash_params['enable_mla_absorption'] = enable_mla_absorption - setattr(self.model.model, "enable_mla_absorption", enable_mla_absorption) + if mla_absorption_config:=kwargs.get('mla_absorption_config', None): + self.hash_params['mla_absorption_config'] = mla_absorption_config + 
setattr(self.model.model, "mla_absorption_config", mla_absorption_config) # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: @@ -3124,16 +3124,14 @@ def export( qaic_config=self.model.qaic_config, ) if enable_mla: - [example_inputs.pop(k) for k in example_inputs.keys() if "past" in k] - [dynamic_axes.pop(k) for k in dynamic_axes.keys() if "past" in k] + example_inputs = {k:v for k,v in example_inputs.items() if "past" not in k} + dynamic_axes = {k:v for k,v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] - example_inputs['compressed_kv'] = [] + example_inputs['compressed_kvs'] = [] for i in range(self.num_layers): - example_inputs['compressed_kv'][i] = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim), dtype=torch.float32) - dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "seq_len"} - output_names.append(f"compressed_kv.{i}_RetainedState") - - + example_inputs['compressed_kvs'].append(torch.zeros((bs, seq_len, self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim), dtype=torch.float32)) + dynamic_axes[f"compressed_kvs.{i}"] = {0: "batch_size", 1: "seq_len"} + output_names.append(f"compressed_kvs.{i}_RetainedState") return self._export( example_inputs, @@ -3289,7 +3287,7 @@ def compile( enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, enable_mla: Optional[bool] = False, - enable_mla_absorption: Optional[bool] = False, + mla_absorption_config: Optional[Dict[str, bool]] = False, **compiler_options, ) -> str: """ @@ -3371,7 +3369,7 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" - if enable_mla_absorption and not enable_mla: + if mla_absorption_config and not enable_mla: logger.warning("enable_mla_fusion will be ignored as enable_mla is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( @@ -3499,11 +3497,16 @@ def compile( # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" custom_io = {} + if not enable_mla: + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + for kv in ["key", "value"]: + custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + else: + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + custom_io[f'compressed_kvs.{i}{suffix}'] = kv_cache_dtype - for suffix in ["", "_RetainedState"]: - for i in range(self.num_layers): - for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -3523,7 +3526,7 @@ def compile( enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, enable_mla = enable_mla, - enable_mla_absorption = enable_mla_absorption, + mla_absorption_config = mla_absorption_config, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index bc2d6a0707..5af8d6bc4c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -956,7 +956,8 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, "DeepseekV3MoE": { "forward": QEffDeepseekV3MoE.forward, - "moe": QEffDeepseekV3MoE.moe + "moe": QEffDeepseekV3MoE.moe, + "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, "DeepseekV3Attention":{ "forward": QEffDeepseekV3Attention.forward, diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py new file mode 100644 index 0000000000..850f0de920 --- /dev/null +++ b/examples/export_kimik2.py @@ -0,0 
+1,18 @@ +import torch +import torch +torch.set_printoptions( + precision=4, + edgeitems=2, + threshold=50, # max number of elements printed + linewidth=120 +) +from transformers import AutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained("moonshotai/Kimi-K2-Thinking", num_hidden_layers=2, trust_remote_code=True, torch_dtype=torch.float32) +qeff_model = QEFFAutoModelForCausalLM(model) + +onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable":True, "online": False}) +print(onnx_path) +qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, mla_absorption_config={"enable":True, "online": False}, + mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=4, num_cores=16) +print(onnx_path) \ No newline at end of file From 4e025c6ccba9fa3d93110a969a550897314eba8a Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 10 Feb 2026 09:49:30 +0000 Subject: [PATCH 06/51] Add prefill only moe changes from kimik2 branch Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 2 +- .../deepseek_v3/modeling_deepseek_qeff.py | 84 ++++++++++++++++++- .../transformers/models/pytorch_transforms.py | 11 ++- examples/export_kimik2.py | 12 ++- examples/run_kimik2.py | 72 +++++++++------- examples/run_orig_kimi_k2.py | 27 ++++++ 6 files changed, 171 insertions(+), 37 deletions(-) create mode 100644 examples/run_orig_kimi_k2.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 1300f70558..3530228c81 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -275,7 +275,7 @@ def _export( else: input_names.append(param) - import ipdb; ipdb.set_trace() + #import ipdb; ipdb.set_trace() try: torch.onnx.export( self.model, diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 
1c67ba11aa..85c6940767 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -604,6 +604,88 @@ def forward(self, hidden_states): # return final_hidden_states.type(hidden_states.dtype) +class QEffPrefillOnlyDeepseekV3MoE(nn.Module): + + def __qeff_init__( + self, + ): + self.all_gate_proj = torch.nn.Parameter( + torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.all_up_proj = torch.nn.Parameter( + torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.all_down_proj = torch.nn.Parameter( + torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) + ) + self.act_fn = self.experts[0].act_fn + + def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + for expert_idx in range(num_experts): + expert = self.experts[expert_idx] + gate_out = expert.gate_proj(hidden_states) + up_out = expert.up_proj(hidden_states) + hidden = expert.act_fn(gate_out) * up_out + expert_output = expert.down_proj(hidden) + current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) + final_hidden_states += current_hidden_states + + return final_hidden_states.type(hidden_states.dtype) + + def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). 
+ """ + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + """ + Forward pass of MoE block. 
+ """ + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) + mask.scatter_(1, topk_indices, topk_weights) + if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: + hidden_states = self.moe_blocked_weights_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: + hidden_states = self.moe_blocked_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + else: + hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) + + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + class QEffDeepseekV3DecoderLayer(nn.Module): """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" @@ -873,7 +955,7 @@ def forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output after_keys = self.state_dict().keys() - import ipdb; ipdb.set_trace() + #import ipdb; ipdb.set_trace() return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5af8d6bc4c..634bc3e329 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -249,7 +249,7 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) -from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3Attention, 
QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3Attention, QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model, QEffPrefillOnlyDeepseekV3MoE from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconDecoderLayer, @@ -732,6 +732,7 @@ class PrefillOnlyTransform(ModuleMappingTransform): QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE, } @@ -741,10 +742,14 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, +<<<<<<< HEAD # Qwen3Moe QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, # Qwen3 VL Moe QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, +======= + QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE, +>>>>>>> ba3218c (Add prefill only moe changes from kimik2 branch) } @@ -756,8 +761,12 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, +<<<<<<< HEAD # Qwen3Moe QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, +======= + QEffPrefillOnlyDeepseekV3MoE: QEffDeepseekV3MoE, +>>>>>>> ba3218c (Add prefill only moe changes from kimik2 branch) } diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py index 850f0de920..fbc8c106e7 100644 --- a/examples/export_kimik2.py +++ b/examples/export_kimik2.py @@ -6,13 +6,19 @@ threshold=50, # max number of elements printed linewidth=120 ) -from transformers import 
AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained("moonshotai/Kimi-K2-Thinking", num_hidden_layers=2, trust_remote_code=True, torch_dtype=torch.float32) + +model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) + qeff_model = QEFFAutoModelForCausalLM(model) onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable":True, "online": False}) print(onnx_path) qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, mla_absorption_config={"enable":True, "online": False}, mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=4, num_cores=16) -print(onnx_path) \ No newline at end of file +print(qpc_path) + +prompts = "Once upon a time," +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 64812371eb..c7f57294e7 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -4,29 +4,26 @@ from QEfficient import QEFFAutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained("moonshotai/Kimi-K2-Thinking", num_hidden_layers=1, trust_remote_code=True, torch_dtype=torch.float32) -import ipdb; ipdb.set_trace() -#model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float16, num_hidden_layers=2, trust_remote_code=True) -# model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) +model = 
AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -PREFILL_SEQ_LEN=128 +PREFILL_SEQ_LEN=32 +CTX_LEN = 128 +generation_len = 10 +generated_ids = [] - -prompts = "Once upon a time," -inputs = tokenizer(prompts, return_tensors="pt", padding=True) +prompt = "Once upon a time," +inputs = tokenizer(prompt, return_tensors="pt", padding=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len - # with torch.no_grad(): # out = model(**inputs) # predictions = torch.argmax(out.logits, dim=-1) - qeff_model = QEFFAutoModelForCausalLM(model) -import ipdb; ipdb.set_trace() -inputs = tokenizer(prompts, return_tensors="np", padding="max_length", max_length=padded_len) + +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs.pop("token_type_ids", None) inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} @@ -41,22 +38,35 @@ past_key_values.append(pkv) inputs["past_key_values"] = past_key_values -# qeff_model.compile(prefill_seq_len=1, ctx_len=1024, mxfp6_matmul=True, num_devices=1) -qeff_out = qeff_model.model(**inputs) - -# assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 -# qeff_model = QEFFAutoModelForCausalLM(model) -# qeff_model.compile( -# prefill_seq_len=1, -# num_devices=1, -# use_onnx_subfunctions=True, -# ctx_len=8192, -# mxfp6_matmul=True, -# # mxint8_kv_cache=True, -# mos=1, -# aic_enable_depth_first=True, -# num_cores=16, -# offload_pt_weights=True, -# ) -# tokenizer = 
AutoTokenizer.from_pretrained("zai-org/GLM-4.7") -# qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) +prefill_qeff_out = qeff_model.model(**inputs) + +#assert (prefill_qeff_out.logits - prefill_out.logits[:, -1, :]).abs().max() < 1e-4 + +position_ids = inputs["position_ids"] +qeff_out = prefill_qeff_out +qeff_generated_ids = [] +for _ in range(1, generation_len): + next_token_id = qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + qeff_generated_ids.append(next_token_id) + position_ids = position_ids.max(1, keepdim=True).values + 1 + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + "past_key_values": qeff_out["past_key_values"], + } + qeff_out = qeff_model.model(**decode_inputs) + +qeff_generated_ids.append(qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) +qeff_generated_ids = np.concatenate(qeff_generated_ids, axis=1) +predicted_string = tokenizer.batch_decode(qeff_generated_ids, skip_special_tokens=True) +print("QEFF Transformed Model Outputs (Torch CPU): \n") +print("Prompt:", repr(prompt)) +print("Completion:", repr(predicted_string)) + +#assert (qeff_generated_ids == generated_ids).all() + + +onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable":True, "online": False}) +qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, mla_absorption_config={"enable":True, "online": False}, mxfp6_matmul=False, mxint8_kv_cache=False, num_devices=4, num_cores=16) + +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/run_orig_kimi_k2.py b/examples/run_orig_kimi_k2.py new file mode 100644 index 0000000000..695377ca06 --- /dev/null +++ b/examples/run_orig_kimi_k2.py @@ -0,0 +1,27 @@ +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + +from QEfficient import QEFFAutoModelForCausalLM + +model = 
AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) + +prompt = "Once upon a time," +inputs = tokenizer(prompt, return_tensors="pt").to(model.device) +with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=10, + do_sample=False, + use_cache=False, + ) + +response = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(response) + +""" +Original Pytorch, kimi-k2 thinking: +Prompt: Once upon a time, +Completion : ?? branchesrupt??? flushedakislottery rehearsallesi +""" From 2c5358d95a9f71bc694ed07d47c442fb4539bf15 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 16 Feb 2026 21:06:06 +0000 Subject: [PATCH 07/51] Change Cache for compressed KV and k_rope Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 4 +- .../generation/text_generation_inference.py | 6 + QEfficient/transformers/cache_utils.py | 82 +++++++++++++- .../deepseek_v3/modeling_deepseek_qeff.py | 103 ++++++------------ .../transformers/models/modeling_auto.py | 28 ++++- examples/compare.py | 94 ++++++++++++++++ examples/run_kimik2.py | 44 +++++--- 7 files changed, 271 insertions(+), 90 deletions(-) create mode 100644 examples/compare.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3530228c81..ca2538d6fc 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -271,7 +271,9 @@ def _export( ) elif param == "compressed_kvs": for i in range(len(example_inputs["compressed_kvs"])): - input_names.extend([f"compressed_kvs.{i}",]) + #input_names.extend([f"compressed_kvs.{i}",]) + input_names.extend([f"compressed_kv.{i}",]) + input_names.extend([f"k_pe.{i}",]) else: input_names.append(param) diff --git 
a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index aa929981b4..9277fb6915 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -499,6 +499,12 @@ def __init__( self._session.skip_buffers( [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] ) + self._session.skip_buffers( + [x for x in self._session.input_names + self._session.output_names if x.startswith("compressed_")] + ) + self._session.skip_buffers( + [x for x in self._session.input_names + self._session.output_names if x.endswith("_pe")] + ) def _set_tokenizer_params(self): """ diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 9fc3936208..921523a806 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -342,7 +342,7 @@ def update3D( return k_out, v_out - +''' class QEffDynamicCompressedKVLayer: def __init__(self, ckv): self.ckv = ckv @@ -369,7 +369,6 @@ def update(self, compressed_kv, cache_kwargs): return ckv_out - class QEffDynamicCompressedKVCache: def __init__(self,): self.layers=[] @@ -394,7 +393,86 @@ def to_legacy_cache(self, ): for layer in self.layers: legacy_cache += (layer.ckv,) return legacy_cache +''' + +class QEffDynamicCompressedKVRopeLayer: + def __init__(self, ckv, k_pe): + self.ckv = ckv + self.k_pe = k_pe + def update_ckv(self, compressed_kv, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later + + self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv) + + ckv_out = self.ckv + ctx_len = ckv_out.shape[1] + ctx_indices = torch.arange(ctx_len)[None, ...] 
+ gather_limit = position_ids.max(1, keepdim=True).values + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices) + ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) + return ckv_out + + def update_k_pe(self, k_pe_cache, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later + + self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) + + k_pe_out = self.k_pe + ctx_len = k_pe_out.shape[-2] + ctx_indices = torch.arange(ctx_len)[None, ...] + gather_limit = position_ids.max(1, keepdim=True).values + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + k_pe_out = CtxGatherFunc.apply(k_pe_out, ctx_indices, ctx_len) + k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out) + return k_pe_out + + +class QEffDynamicCompressedKVRopeCache: + def __init__(self,): + self.layers=[] + + def add_new(self, ckv, k_pe, layer_idx): + self.layers.append(QEffDynamicCompressedKVRopeLayer(ckv, k_pe)) + + def update_ckv(self, ckv, layer_idx, cache_kwargs): + return self.layers[layer_idx].update_ckv(ckv, cache_kwargs) + + def update_k_pe(self, k_pe, layer_idx, cache_kwargs): + return self.layers[layer_idx].update_k_pe(k_pe, cache_kwargs) + + @classmethod + def from_legacy_cache(cls, past_key_values): + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + ckv, k_pe = past_key_values[layer_idx] + cache.add_new(ckv, k_pe, layer_idx) + 
return cache + + def to_legacy_cache(self, ): + legacy_cache = () + for layer in self.layers: + x = (layer.ckv, layer.k_pe) + legacy_cache += (x,) + return legacy_cache + + class QEffDynamicCache(Cache): diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 85c6940767..611b9096ee 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -25,7 +25,7 @@ # logger, # ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVCache +from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache def rotate_half(x): @@ -370,6 +370,7 @@ def __qeff_init__(self,): # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) + def fused_forward( self, hidden_states: torch.Tensor, @@ -386,39 +387,35 @@ def fused_forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - # import ipdb; ipdb.set_trace() - if compressed_kvs is not None: - cache_kwargs = { - "position_ids": position_ids, - "batch_index": batch_index, - } - compressed_kv = compressed_kvs.update(compressed_kv, self.layer_idx, cache_kwargs) + #compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states) + #k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - # import ipdb; ipdb.set_trace() q_pe = torch.bmm(q_a_proj_out, 
self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - - - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) kva = self.kv_a_layernorm(compressed_kv) k_nope = torch.bmm(kva, self.k_up) - k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + k_nope = k_nope.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) value_states = torch.bmm(kva, self.v_up) - value_states = value_states.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, seq_len=32*1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) absorb_online = mla_absorption.get("online", False) @@ -428,10 +425,8 @@ def fused_forward( if enable_absorption: if absorb_online: print("online absorption") - atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), kva.transpose(1, 2).unsqueeze(1)) else: - print("using fused qk") atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)) else: @@ -443,23 +438,16 @@ def fused_forward( if attention_mask is not None: # no matter the length, we just slice it 
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - # ctx_len = past_key_value[self.layer_idx][0].shape[2] - # ctx_indices = torch.arange(ctx_len)[None, None, ...] - # gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - # invalid_mask = ctx_indices > gather_limit - # value_states = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), value_states) - # attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - # if not output_attentions: - # attn_weights = None - return attn_output, attn_weights, compressed_kvs, value_states + def forward( self, hidden_states: torch.Tensor, @@ -499,13 +487,12 @@ def forward( ) cos, sin = self.rotary_emb(value_states, seq_len=32*1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = torch.cat((q_nope, q_pe), -1) k_pe_new = k_pe.repeat(1,self.num_heads,1,1) key_states = torch.cat((k_nope, k_pe_new), -1) - # import ipdb; ipdb.set_trace() + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -522,7 +509,6 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value, value_states @@ -579,12 +565,11 @@ def moe( def forward(self, hidden_states): residuals = hidden_states orig_shape = hidden_states.shape - # breakpoint() topk_indices, topk_weights = 
self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states + self.shared_experts(residuals) + return hidden_states # def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): # final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) @@ -736,7 +721,6 @@ def forward( cache_position=cache_position, **kwargs, ) - # import ipdb; ipdb.set_trace() hidden_states = residual + hidden_states residual = hidden_states @@ -814,7 +798,7 @@ def forward( enable_mla = getattr(self, "enable_mla", False) if enable_mla: - compressed_kvs = QEffDynamicCompressedKVCache.from_legacy_cache(compressed_kvs) + compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs) target_len = compressed_kvs.layers[0].ckv.shape[-2] else: target_len = past_key_values[0][0].shape[2] @@ -840,35 +824,21 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - batch_index, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - compressed_kvs = compressed_kvs, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - enable_mla = getattr(self, "enable_mla", False), - mla_absorption = getattr(self, "mla_absorption_config", None), - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + 
position_ids=position_ids, + compressed_kvs = compressed_kvs, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + enable_mla = getattr(self, "enable_mla", False), + mla_absorption = getattr(self, "mla_absorption_config", None), + **kwargs, + ) hidden_states = layer_outputs[0] if use_cache: @@ -954,8 +924,7 @@ def forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - after_keys = self.state_dict().keys() - #import ipdb; ipdb.set_trace() + return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index be9a307323..4a1592cc09 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2686,6 +2686,14 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [] + def mla( + self, + enable_mla: Optional[bool] = False, + mla_absorption_config: Optional[Dict[str, bool]] = False, + ): + setattr(self.model.model, "enable_mla", enable_mla) + setattr(self.model.model, "mla_absorption_config", mla_absorption_config) + def prefill( self, enable: Optional[bool] = True, @@ -2992,6 +3000,8 @@ def export( # kv_cache_shape = get_padding_shape_from_config( # self.model.config, fbs if self.continuous_batching else bs, seq_len # ) + ckv_shape = (1,seq_len, 512) + k_pe_shape = (1,1, seq_len, 64) kv_cache_shape = (1, 64, seq_len, 192) kv_cache_shape_v = (1, 64, seq_len, 128) enable_chunking = kwargs.get("enable_chunking", False) @@ -3127,12 +3137,17 @@ def export( example_inputs = {k:v for k,v in example_inputs.items() if "past" not in k} dynamic_axes = {k:v for k,v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] - 
example_inputs['compressed_kvs'] = [] + example_inputs['compressed_kvs'] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - example_inputs['compressed_kvs'].append(torch.zeros((bs, seq_len, self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim), dtype=torch.float32)) - dynamic_axes[f"compressed_kvs.{i}"] = {0: "batch_size", 1: "seq_len"} - output_names.append(f"compressed_kvs.{i}_RetainedState") - + ckv = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) + k_pe = torch.zeros((bs, 1, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + example_inputs['compressed_kvs'][i].append(ckv) + example_inputs['compressed_kvs'][i].append(k_pe) + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "seq_len"} + dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "seq_len"} + output_names.append(f"compressed_kv.{i}_RetainedState") + output_names.append(f"k_pe.{i}_RetainedState") + return self._export( example_inputs, output_names=output_names, @@ -3505,7 +3520,8 @@ def compile( else: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): - custom_io[f'compressed_kvs.{i}{suffix}'] = kv_cache_dtype + custom_io[f'compressed_kv.{i}{suffix}'] = kv_cache_dtype + custom_io[f'k_pe.{i}{suffix}'] = kv_cache_dtype qpc_path = self._compile( onnx_path=onnx_path, diff --git a/examples/compare.py b/examples/compare.py new file mode 100644 index 0000000000..5113670559 --- /dev/null +++ b/examples/compare.py @@ -0,0 +1,94 @@ +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + +from QEfficient import QEFFAutoModelForCausalLM + +prompt = "Once upon a time," + +model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +tokenizer = 
AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) +PREFILL_SEQ_LEN=128 +CTX_LEN = 128 +generation_len = 5 +generated_ids = [] + +inputs = tokenizer(prompt, return_tensors="pt", padding=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +with torch.no_grad(): + out = model(**inputs) + predictions = torch.argmax(out.logits, dim=-1) + + +qeff_model_no_mla = QEFFAutoModelForCausalLM(model) + +qeff_model_mla = QEFFAutoModelForCausalLM(model) + +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +compressed_kvs = [] +for i in range(model.config.num_hidden_layers): + cache_len = 128 + pad_shape_k = (1, 64, cache_len, 192) + pad_shape_v = (1, 64, cache_len, 128) + past_key = torch.zeros((pad_shape_k), dtype=torch.float32) + past_value = torch.zeros((pad_shape_v), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + compressed_kvs.append(torch.zeros(1, cache_len, 576)) +inputs['compressed_kvs'] = compressed_kvs + +prefill_qeff_out_mla = qeff_model_mla.model(**inputs) + +inputs.pop("compressed_kvs") +inputs["past_key_values"] = past_key_values +prefill_qeff_out_no_mla = qeff_model_no_mla.model(**inputs) +breakpoint() +assert (prefill_qeff_out_mla.logits - out.logits[:, -1, :]).abs().max() < 1e-4 +assert (prefill_qeff_out_no_mla.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + +position_ids = inputs["position_ids"] +qeff_out_mla = prefill_qeff_out_mla +qeff_out_no_mla = prefill_qeff_out_no_mla +qeff_mla_generated_ids = [] +qeff_no_mla_generated_ids = [] +for _ in range(1, generation_len): + 
next_token_id_mla = qeff_out_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + next_token_id_no_mla = qeff_out_no_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + qeff_mla_generated_ids.append(next_token_id_mla) + qeff_no_mla_generated_ids.append(next_token_id_no_mla) + position_ids = position_ids.max(1, keepdim=True).values + 1 + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + "compressed_kvs": qeff_out_mla["past_key_values"], + } + qeff_out_mla = qeff_model_mla.model(**decode_inputs) + + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + "past_key_values": qeff_out_no_mla["past_key_values"], + } + qeff_out_no_mla = qeff_model_no_mla.model(**decode_inputs) + breakpoint() + +qeff_mla_generated_ids.append(qeff_out_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) +qeff_mla_generated_ids = np.concatenate(qeff_mla_generated_ids, axis=1) +predicted_string = tokenizer.batch_decode(qeff_mla_generated_ids, skip_special_tokens=True) +print("QEFF Transformed Model Outputs (Torch CPU): \n") +print("Prompt:", repr(prompt)) +print("Completion:", repr(predicted_string)) + +qeff_no_mla_generated_ids.append(qeff_out_no_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) +qeff_no_mla_generated_ids = np.concatenate(qeff_no_mla_generated_ids, axis=1) +predicted_string = tokenizer.batch_decode(qeff_no_mla_generated_ids, skip_special_tokens=True) +print("QEFF Transformed Model Outputs (Torch CPU): \n") +print("Prompt:", repr(prompt)) +print("Completion:", repr(predicted_string)) diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index c7f57294e7..b4a5f3d8e0 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -4,6 +4,8 @@ from QEfficient import QEFFAutoModelForCausalLM +prompt = "Once upon a time," + model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, 
num_hidden_layers=2, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) PREFILL_SEQ_LEN=32 @@ -11,36 +13,50 @@ generation_len = 10 generated_ids = [] -prompt = "Once upon a time," inputs = tokenizer(prompt, return_tensors="pt", padding=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -# with torch.no_grad(): -# out = model(**inputs) -# predictions = torch.argmax(out.logits, dim=-1) +with torch.no_grad(): + out = model(**inputs) + predictions = torch.argmax(out.logits, dim=-1) qeff_model = QEFFAutoModelForCausalLM(model) +qeff_model.mla(enable_mla=True, mla_absorption_config={"enable":False, "online": False}) + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs.pop("token_type_ids", None) inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + +cache_len = 128 +pad_shape_k = (1, 64, cache_len, 192) +pad_shape_v = (1, 64, cache_len, 128) +pad_shape_ckv = (1, cache_len, 512) +pad_shape_k_pe = (1, 1, cache_len, 64) + past_key_values = [] +compressed_kvs = [] + for i in range(model.config.num_hidden_layers): - cache_len = 128 - pad_shape_k = (1, 64, cache_len, 192) - pad_shape_v = (1, 64, cache_len, 128) past_key = torch.zeros((pad_shape_k), dtype=torch.float32) past_value = torch.zeros((pad_shape_v), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) -inputs["past_key_values"] = past_key_values + + ckv = torch.zeros((pad_shape_ckv), dtype=torch.float32) + k_pe = torch.zeros((pad_shape_k_pe), dtype=torch.float32) + x = (ckv, k_pe) + compressed_kvs.append(x) + + +inputs["compressed_kvs"] = compressed_kvs prefill_qeff_out = qeff_model.model(**inputs) -#assert 
(prefill_qeff_out.logits - prefill_out.logits[:, -1, :]).abs().max() < 1e-4 +assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 position_ids = inputs["position_ids"] qeff_out = prefill_qeff_out @@ -52,7 +68,7 @@ decode_inputs = { "input_ids": next_token_id, "position_ids": position_ids, - "past_key_values": qeff_out["past_key_values"], + "compressed_kvs": qeff_out["past_key_values"], } qeff_out = qeff_model.model(**decode_inputs) @@ -63,10 +79,10 @@ print("Prompt:", repr(prompt)) print("Completion:", repr(predicted_string)) -#assert (qeff_generated_ids == generated_ids).all() - -onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable":True, "online": False}) -qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, mla_absorption_config={"enable":True, "online": False}, mxfp6_matmul=False, mxint8_kv_cache=False, num_devices=4, num_cores=16) +onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True)#, mla_absorption_config={"enable":True, "online": False}) +qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, #mla_absorption_config={"enable":True, "online": False}, +mxfp6_matmul=False, mxint8_kv_cache=False, num_devices=2, num_cores=16) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) + From 30c57e9fbb18b861fb2d786176c0de562e3854e4 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 18 Feb 2026 19:16:05 +0000 Subject: [PATCH 08/51] fix dynamic axis and output mismatch Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 16 +- .../generation/text_generation_inference.py | 3 +- QEfficient/transformers/cache_utils.py | 88 +---- .../models/deepseek_v3/modeling_deepseek.py | 364 +++++------------- .../deepseek_v3/modeling_deepseek_qeff.py | 255 +++--------- .../transformers/models/modeling_auto.py | 41 +- .../transformers/models/pytorch_transforms.py | 28 +- examples/compare.py | 94 ----- examples/export_kimik2.py 
| 30 +- .../causallm/example_pytorch_transforms.py | 12 +- examples/run_kimik2.py | 33 +- examples/run_orig_kimi_k2.py | 12 +- 12 files changed, 280 insertions(+), 696 deletions(-) delete mode 100644 examples/compare.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ca2538d6fc..b6da4dcf37 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -271,13 +271,21 @@ def _export( ) elif param == "compressed_kvs": for i in range(len(example_inputs["compressed_kvs"])): - #input_names.extend([f"compressed_kvs.{i}",]) - input_names.extend([f"compressed_kv.{i}",]) - input_names.extend([f"k_pe.{i}",]) + # input_names.extend([f"compressed_kvs.{i}",]) + input_names.extend( + [ + f"compressed_kv.{i}", + ] + ) + input_names.extend( + [ + f"k_pe.{i}", + ] + ) else: input_names.append(param) - #import ipdb; ipdb.set_trace() + # import ipdb; ipdb.set_trace() try: torch.onnx.export( self.model, diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 9277fb6915..991573e811 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -503,7 +503,7 @@ def __init__( [x for x in self._session.input_names + self._session.output_names if x.startswith("compressed_")] ) self._session.skip_buffers( - [x for x in self._session.input_names + self._session.output_names if x.endswith("_pe")] + [x for x in self._session.input_names + self._session.output_names if x.startswith("k_pe")] ) def _set_tokenizer_params(self): @@ -846,6 +846,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i ] if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + outputs = self._session.run(chunk_inputs) if self._write_io_dir is not None: diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 
921523a806..e3d88f706f 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -342,70 +342,18 @@ def update3D( return k_out, v_out -''' -class QEffDynamicCompressedKVLayer: - def __init__(self, ckv): - self.ckv = ckv - - def update(self, compressed_kv, cache_kwargs): - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later - - self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv) - - ckv_out = self.ckv - ctx_len = ckv_out.shape[1] - ctx_indices = torch.arange(ctx_len)[None, ...] - gather_limit = position_ids.max(1, keepdim=True).values - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices) - ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) - return ckv_out - - -class QEffDynamicCompressedKVCache: - def __init__(self,): - self.layers=[] - - def add_new(self, ckv, layer_idx): - self.layers.append(QEffDynamicCompressedKVLayer(ckv)) - - def update(self, ckv, layer_idx, cache_kwargs): - return self.layers[layer_idx].update(ckv, cache_kwargs) - - @classmethod - def from_legacy_cache(cls, past_key_values): - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - ckv = past_key_values[layer_idx] - cache.add_new(ckv, layer_idx) - return cache - - def to_legacy_cache(self, ): - legacy_cache = () - for layer in self.layers: - legacy_cache += (layer.ckv,) - return legacy_cache -''' class QEffDynamicCompressedKVRopeLayer: def __init__(self, ckv, k_pe): self.ckv = ckv self.k_pe = k_pe - + def update_ckv(self, compressed_kv, cache_kwargs): position_ids = cache_kwargs.get("position_ids") batch_index = 
cache_kwargs.get("batch_index", None) # TODO: add support later - + self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv) - + ckv_out = self.ckv ctx_len = ckv_out.shape[1] ctx_indices = torch.arange(ctx_len)[None, ...] @@ -416,15 +364,15 @@ def update_ckv(self, compressed_kv, cache_kwargs): else: invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - + ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices) ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) return ckv_out - + def update_k_pe(self, k_pe_cache, cache_kwargs): position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later - + self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) k_pe_out = self.k_pe @@ -437,25 +385,27 @@ def update_k_pe(self, k_pe_cache, cache_kwargs): else: invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - + k_pe_out = CtxGatherFunc.apply(k_pe_out, ctx_indices, ctx_len) k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out) return k_pe_out - - + + class QEffDynamicCompressedKVRopeCache: - def __init__(self,): - self.layers=[] - + def __init__( + self, + ): + self.layers = [] + def add_new(self, ckv, k_pe, layer_idx): self.layers.append(QEffDynamicCompressedKVRopeLayer(ckv, k_pe)) - + def update_ckv(self, ckv, layer_idx, cache_kwargs): return self.layers[layer_idx].update_ckv(ckv, cache_kwargs) def update_k_pe(self, k_pe, layer_idx, cache_kwargs): return self.layers[layer_idx].update_k_pe(k_pe, cache_kwargs) - + @classmethod def from_legacy_cache(cls, past_key_values): cache = cls() @@ -465,15 +415,15 @@ def from_legacy_cache(cls, past_key_values): cache.add_new(ckv, k_pe, layer_idx) return cache - def to_legacy_cache(self, ): + def to_legacy_cache( + self, + ): legacy_cache = () for layer in self.layers: x = 
(layer.ckv, layer.k_pe) legacy_cache += (x,) return legacy_cache - - class QEffDynamicCache(Cache): """ diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 1192a0063d..5eff081888 100755 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -17,22 +17,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch DeepSeek model.""" +"""PyTorch DeepSeek model.""" + import math import warnings from typing import List, Optional, Tuple, Union +import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( @@ -54,9 +54,8 @@ replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available + from .configuration_deepseek import DeepseekV3Config -import torch.distributed as dist -import numpy as np if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -81,9 +80,7 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) - ) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, 
dtype=torch.torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -118,9 +115,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - ) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -133,9 +128,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -171,9 +164,7 @@ def __init__( def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -203,17 +194,12 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len > self.max_position_embeddings: base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1) + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - ) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 
self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -223,24 +209,14 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) +def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) # Find dim range bounds based on rotations -def yarn_find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) +def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) # Clamp values just in case @@ -260,7 +236,6 @@ def yarn_linear_ramp_mask(min, max, dim): class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): - def __init__( self, dim, @@ -286,14 +261,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len dim = self.dim - freq_extra = 1.0 / ( - self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - 
) + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) freq_inter = 1.0 / ( - self.scaling_factor - * self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( @@ -303,9 +273,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.base, self.original_max_position_embeddings, ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( - device=device, dtype=torch.float32 - ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -319,12 +287,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): ) emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False - ) + self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -376,9 +340,7 @@ def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = ( - config.intermediate_size if intermediate_size is None else intermediate_size - ) + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ 
-406,13 +368,9 @@ def __init__(self, config): # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob self.gating_dim = config.hidden_size - self.weight = nn.Parameter( - torch.empty((self.n_routed_experts, self.gating_dim)) - ) + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter( - torch.empty((self.n_routed_experts)) - ) + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) self.reset_parameters() def reset_parameters(self) -> None: @@ -424,55 +382,42 @@ def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) - logits = F.linear( - hidden_states.type(torch.float32), self.weight.type(torch.float32), None - ) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) if self.scoring_func == "sigmoid": scores = logits.sigmoid() else: - raise NotImplementedError( - f"insupportable scoring function for MoE gating: {self.scoring_func}" - ) + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") ### select top-k experts if self.topk_method == "noaux_tc": assert not self.training scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) ) # [n, n_group] - group_idx = torch.topk( - group_scores, k=self.topk_group, dim=-1, sorted=False - )[ - 1 - ] # [n, top_k_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) - .expand( - 
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group - ) + .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) .reshape(bsz * seq_len, -1) ) # [n, e] tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk( - tmp_scores, k=self.top_k, dim=-1, sorted=False - ) + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) topk_weight = scores.gather(1, topk_idx) else: - raise NotImplementedError( - f"insupportable TopK function for MoE gating: {self.topk_method}" - ) + raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor return topk_idx, topk_weight + class DeepseekV3MoE(nn.Module): """ A mixed expert module containing shared experts. 
@@ -491,11 +436,8 @@ def __init__(self, config): self.experts = nn.ModuleList( [ ( - DeepseekV3MLP( - config, intermediate_size=config.moe_intermediate_size - ) - if i >= self.ep_rank * self.experts_per_rank - and i < (self.ep_rank + 1) * self.experts_per_rank + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank else None ) for i in range(config.n_routed_experts) @@ -507,18 +449,14 @@ def __init__(self, config): self.ep_rank = 0 self.experts = nn.ModuleList( [ - DeepseekV3MLP( - config, intermediate_size=config.moe_intermediate_size - ) + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) for i in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP( - config=config, intermediate_size=intermediate_size - ) + self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) def forward(self, hidden_states): identity = hidden_states @@ -542,17 +480,9 @@ def moe_infer(self, x, topk_ids, topk_weight): sorted_tokens_shape = sorted_tokens.shape if self.ep_size > 1: tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty( - tokens_per_expert.shape[0] - ) + tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = ( - tokens_per_expert_group.view(self.ep_size, -1) - .sum(1) - .cpu() - .numpy() - .tolist() - ) + output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() gathered_tokens = sorted_tokens.new_empty( tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] ) @@ -561,9 +491,7 @@ def moe_infer(self, x, topk_ids, 
topk_weight): list(gathered_tokens.split(output_splits)), list(sorted_tokens.split(input_split_sizes)), ) - tokens_per_expert_post_gather = tokens_per_expert_group.view( - self.ep_size, self.experts_per_rank - ).sum(dim=0) + tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0) gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) s = 0 for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): @@ -618,9 +546,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -655,17 +581,11 @@ def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): self.is_causal = True if self.q_lora_rank is None: - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.q_head_dim, bias=False - ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) else: - self.q_a_proj = nn.Linear( - self.hidden_size, config.q_lora_rank, bias=config.attention_bias - ) + self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear( - config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False - ) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, @@ -675,8 +595,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) self.kv_b_proj = 
nn.Linear( config.kv_lora_rank, - self.num_heads - * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) @@ -742,11 +661,7 @@ def _init_rope(self): raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) - .transpose(1, 2) - .contiguous() - ) + return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() def forward( self, @@ -769,14 +684,10 @@ def forward( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) @@ -784,9 +695,7 @@ def forward( .transpose(1, 2) ) - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -809,13 +718,9 @@ def forward( key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + 
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - ) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -831,12 +736,8 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): @@ -901,17 +802,13 @@ def forward( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) @@ -919,9 +816,7 @@ def forward( 
.transpose(1, 2) ) - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] kv_seq_len = value_states.shape[-2] @@ -944,9 +839,7 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -970,11 +863,7 @@ def forward( elif torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: - target_dtype = ( - self.q_proj.weight.dtype - if self.q_lora_rank is None - else self.q_a_proj.weight.dtype - ) + target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" @@ -998,9 +887,7 @@ def forward( if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] - attn_output = attn_output.reshape( - bsz, q_len, self.num_heads * self.v_head_dim - ).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -1053,9 +940,7 @@ def _flash_attention_forward( indices_q, cu_seq_lens, max_seq_lens, - ) = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) cu_seqlens_q, 
cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -1073,9 +958,7 @@ def _flash_attention_forward( causal=causal, ) - attn_output = pad_input( - attn_output_unpad, indices_q, batch_size, query_length - ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( query_states, @@ -1088,9 +971,7 @@ def _flash_attention_forward( return attn_output - def _upad_input( - self, query_layer, key_layer, value_layer, attention_mask, query_length - ): + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape @@ -1120,9 +1001,7 @@ def _upad_input( else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask - ) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -1145,9 +1024,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) + self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = ( DeepseekV3MoE(config) @@ -1158,12 +1035,8 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): ) else DeepseekV3MLP(config) ) - self.input_layernorm = DeepseekV3RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = DeepseekV3RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = 
DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -1174,9 +1047,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1357,14 +1228,9 @@ def __init__(self, config: DeepseekV3Config): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [ - DeepseekV3DecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1392,27 +1258,17 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: @@ -1442,11 +1298,7 @@ def forward( if self._use_flash_attention_2: # 2d mask is passed through the layers - attention_mask = ( - attention_mask - if (attention_mask is not None and 0 in attention_mask) - else None - ) + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -1493,17 +1345,9 @@ def forward( next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache - else next_decoder_cache - ) + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1543,9 +1387,7 @@ def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1584,19 +1426,11 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, 
clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1661,10 +1495,7 @@ def prepare_inputs_for_generation( # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as # input) - if ( - attention_mask is not None - and attention_mask.shape[1] > input_ids.shape[1] - ): + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. @@ -1709,10 +1540,7 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ), + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past @@ -1768,9 +1596,7 @@ def forward( config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, @@ -1792,22 +1618,18 @@ def forward( batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." - ) + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = ( - torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - ).to(logits.device) + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) else: sequence_lengths = -1 - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: @@ -1815,9 +1637,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1830,9 +1650,7 @@ def forward( loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct( - pooled_logits.view(-1, self.num_labels), labels.view(-1) - ) + loss = loss_fct(pooled_logits.view(-1, 
self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 611b9096ee..9e924f4253 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -4,10 +4,10 @@ import torch import torch.nn.functional as F from torch import nn -from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + +from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache # Assuming these are imported from the original DeepseekV3 code # from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( @@ -25,7 +25,6 @@ # logger, # ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache def rotate_half(x): @@ -35,40 +34,15 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - # Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) +def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) # Find dim range bounds based on rotations -def yarn_find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) +def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) # Clamp values just in case @@ -94,9 +68,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - ) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 
self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -109,9 +81,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -128,7 +98,7 @@ def forward(self, x, seq_len=None): self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) - + # def forward(self, x, position_ids): # seq_len = torch.max(position_ids) + 1 # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: @@ -140,9 +110,7 @@ def forward(self, x, seq_len=None): # return cos.to(x.dtype), sin.to(x.dtype) - class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): - def __init__( self, dim, @@ -168,14 +136,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len dim = self.dim - freq_extra = 1.0 / ( - self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) freq_inter = 1.0 / ( - self.scaling_factor - * self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( @@ -185,9 +148,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.base, self.original_max_position_embeddings, ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( - device=device, dtype=torch.float32 - ) + inv_freq_mask = 1.0 - 
yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -201,106 +162,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): ) emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False - ) - - -class QEffDeepseekV3RotaryEmbedding(nn.Module): - """ - Adapted from DeepseekV3RotaryEmbedding with static sin/cos caches like QEffLlamaRotaryEmbedding. - """ - - def __init__(self, config, device=None): - super().__init__() - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Precompute static sin/cos caches - self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, position_ids): - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: - 
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - # Use position_ids to slice the precomputed caches - cos = self.cos_cached[position_ids] * self.attention_scaling - sin = self.sin_cached[position_ids] * self.attention_scaling - return cos.to(x.dtype), sin.to(x.dtype) - - -def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches) - if cos.shape[-2] > q.shape[-2]: - cos = cos[:, position_ids, :] - sin = sin[:, position_ids, :] - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Adapted from DeepseekV3's apply_rotary_pos_emb for QEff compatibility with position_ids slicing.""" - # Slice cos and sin using position_ids if they are larger (e.g., precomputed caches) - if cos.shape[-2] > q.shape[-2]: - cos = cos[:, position_ids, :] - sin = sin[:, position_ids, :] - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed.to(q.dtype), k_embed.to(k.dtype) + self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb @@ -325,7 +188,6 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - # import ipdb; ipdb.set_trace() cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) @@ -340,37 +202,45 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed - class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" - def __qeff_init__(self,): - q_up, q_rope = self.q_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.qk_rope_head_dim).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_up = q_up.reshape(-1, self.num_heads*self.qk_nope_head_dim).unsqueeze(0) + def __qeff_init__( + self, + ): + q_up, q_rope = self.q_b_proj.weight.T.view( + -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim + ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) # self.register_buffer("q_up", q_up.detach().clone(), persistent=False) self.q_up = torch.nn.Parameter(q_up.detach().clone()) - q_rope = q_rope.reshape(-1, self.num_heads* self.qk_rope_head_dim).unsqueeze(0) + q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) # self.register_buffer("q_rope", q_rope.detach().clone(), persistent=False) self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) - k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim+self.v_head_dim).split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_up = k_up.reshape(-1, self.num_heads*self.qk_nope_head_dim).unsqueeze(0) - v_up = v_up.reshape(-1, self.num_heads*self.v_head_dim).unsqueeze(0) + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + v_up = v_up.reshape(-1, self.num_heads * 
self.v_head_dim).unsqueeze(0) # self.register_buffer("k_up", k_up.detach().clone(), persistent=False) # self.register_buffer("v_up", v_up.detach().clone(), persistent=False) self.k_up = torch.nn.Parameter(k_up.detach().clone()) self.v_up = torch.nn.Parameter(v_up.detach().clone()) per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - per_head_k_up = self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) - # import ipdb; ipdb.set_trace() + per_head_k_up = ( + self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) + ) # self.register_buffer("per_head_q_up", per_head_q_up.detach().clone(), persistent=False) # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) - self.per_haed_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) + self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) fusedqk = torch.bmm(per_head_q_up, per_head_k_up) # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - + # self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias) + # self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias) + def fused_forward( self, hidden_states: torch.Tensor, @@ -389,11 +259,11 @@ def fused_forward( bsz, q_len, _ = hidden_states.size() compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - #compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states) - #k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states) + # compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states) + # k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, 
self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.bmm(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) @@ -410,7 +280,7 @@ def fused_forward( value_states = torch.bmm(kva, self.v_up) value_states = value_states.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=32*1024) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: @@ -425,16 +295,21 @@ def fused_forward( if enable_absorption: if absorb_online: print("online absorption") - atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), kva.transpose(1, 2).unsqueeze(1)) + atn = torch.matmul( + torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), + kva.transpose(1, 2).unsqueeze(1), + ) else: print("using fused qk") - atn = torch.matmul(torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1)) + atn = torch.matmul( + torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1) + ) else: print("no absorption") atn = torch.matmul(q_nope, k_nope.transpose(2, 3)) - - atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) - attn_weights = (atn+atr) * self.softmax_scale + + atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) + attn_weights = (atn + atr) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) @@ -447,7 +322,6 @@ def fused_forward( return attn_output, attn_weights, compressed_kvs, value_states - 
def forward( self, hidden_states: torch.Tensor, @@ -467,14 +341,10 @@ def forward( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) @@ -482,15 +352,13 @@ def forward( .transpose(1, 2) ) - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - cos, sin = self.rotary_emb(value_states, seq_len=32*1024) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.repeat(1,self.num_heads,1,1) + k_pe_new = k_pe.repeat(1, self.num_heads, 1, 1) key_states = torch.cat((k_nope, k_pe_new), -1) if past_key_value is not None: @@ -590,7 +458,6 @@ def forward(self, hidden_states): class QEffPrefillOnlyDeepseekV3MoE(nn.Module): - def __qeff_init__( self, ): @@ -693,7 +560,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) if enable_mla: - hidden_states, self_attn_weights, present_compressed_kvs, vs = self.self_attn.fused_forward( + hidden_states, self_attn_weights, present_compressed_kvs, vs = self.self_attn.fused_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, 
@@ -758,7 +625,7 @@ def __qeff_init__(self): } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( self.config.qk_rope_head_dim, - max_position_embeddings=32*1024, + max_position_embeddings=32 * 1024, scaling_factor=scaling_factor, base=self.config.rope_theta, **kwargs, @@ -828,15 +695,15 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, - compressed_kvs = compressed_kvs, + compressed_kvs=compressed_kvs, past_key_value=past_key_values, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - enable_mla = getattr(self, "enable_mla", False), - mla_absorption = getattr(self, "mla_absorption_config", None), + enable_mla=getattr(self, "enable_mla", False), + mla_absorption=getattr(self, "mla_absorption_config", None), **kwargs, ) @@ -895,7 +762,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - compressed_kvs = compressed_kvs, + compressed_kvs=compressed_kvs, past_key_values=past_key_values, batch_index=batch_index, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4a1592cc09..257d50009e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -72,7 +72,6 @@ ) from QEfficient.utils import ( constants, - get_padding_shape_from_config, ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -3000,20 +2999,20 @@ def export( # kv_cache_shape = get_padding_shape_from_config( # self.model.config, fbs if self.continuous_batching else bs, seq_len # ) - ckv_shape = (1,seq_len, 512) - k_pe_shape = (1,1, seq_len, 64) + ckv_shape = (1, seq_len, 512) + k_pe_shape = (1, 1, seq_len, 64) kv_cache_shape = (1, 64, seq_len, 192) kv_cache_shape_v = (1, 64, seq_len, 128) 
enable_chunking = kwargs.get("enable_chunking", False) # TODO: HACK handle better - if enable_mla:=kwargs.get('enable_mla', False): - self.hash_params['enable_mla'] = enable_mla + if enable_mla := kwargs.get("enable_mla", False): + self.hash_params["enable_mla"] = enable_mla setattr(self.model.model, "enable_mla", enable_mla) - if mla_absorption_config:=kwargs.get('mla_absorption_config', None): - self.hash_params['mla_absorption_config'] = mla_absorption_config + if mla_absorption_config := kwargs.get("mla_absorption_config", None): + self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) - + # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: @@ -3134,20 +3133,20 @@ def export( qaic_config=self.model.qaic_config, ) if enable_mla: - example_inputs = {k:v for k,v in example_inputs.items() if "past" not in k} - dynamic_axes = {k:v for k,v in dynamic_axes.items() if "past" not in k} + example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} + dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] - example_inputs['compressed_kvs'] = [[] for _ in range(self.num_layers)] + example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): ckv = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) k_pe = torch.zeros((bs, 1, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) - example_inputs['compressed_kvs'][i].append(ckv) - example_inputs['compressed_kvs'][i].append(k_pe) - dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "seq_len"} - dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "seq_len"} + example_inputs["compressed_kvs"][i].append(ckv) + example_inputs["compressed_kvs"][i].append(k_pe) + 
dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "ctx_len"} + dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} output_names.append(f"compressed_kv.{i}_RetainedState") output_names.append(f"k_pe.{i}_RetainedState") - + return self._export( example_inputs, output_names=output_names, @@ -3306,7 +3305,7 @@ def compile( **compiler_options, ) -> str: """ - + Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler. This method generates a ``qpc`` package. If the model has not been exported yet, @@ -3520,8 +3519,8 @@ def compile( else: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): - custom_io[f'compressed_kv.{i}{suffix}'] = kv_cache_dtype - custom_io[f'k_pe.{i}{suffix}'] = kv_cache_dtype + custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype + custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype qpc_path = self._compile( onnx_path=onnx_path, @@ -3541,8 +3540,8 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, - enable_mla = enable_mla, - mla_absorption_config = mla_absorption_config, + enable_mla=enable_mla, + mla_absorption_config=mla_absorption_config, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 634bc3e329..cf9350cdfc 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -249,7 +249,14 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) -from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import QEffDeepseekV3Attention, QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3MoE, QEffDeepseekV3Model, QEffPrefillOnlyDeepseekV3MoE +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import ( + QEffDeepseekV3Attention, + QEffDeepseekV3DecoderLayer, + 
QEffDeepseekV3ForCausalLM, + QEffDeepseekV3Model, + QEffDeepseekV3MoE, + QEffPrefillOnlyDeepseekV3MoE, +) from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconDecoderLayer, @@ -953,13 +960,10 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "RMSNorm": { "forward": QEFFGrok1CustomRMSNormAIC.forward, }, - "DeepseekV3ForCausalLM":{ + "DeepseekV3ForCausalLM": { "forward": QEffDeepseekV3ForCausalLM.forward, }, - "DeepseekV3Model":{ - "forward": QEffDeepseekV3Model.forward, - "__qeff_init__": QEffDeepseekV3Model.__qeff_init__ - }, + "DeepseekV3Model": {"forward": QEffDeepseekV3Model.forward, "__qeff_init__": QEffDeepseekV3Model.__qeff_init__}, "DeepseekV3DecoderLayer": { "forward": QEffDeepseekV3DecoderLayer.forward, }, @@ -968,14 +972,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "moe": QEffDeepseekV3MoE.moe, "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, - "DeepseekV3Attention":{ - "forward": QEffDeepseekV3Attention.forward, - "fused_forward": QEffDeepseekV3Attention.fused_forward, - "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, + "DeepseekV3Attention": { + "forward": QEffDeepseekV3Attention.forward, + "fused_forward": QEffDeepseekV3Attention.fused_forward, + "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, - "DeepseekV3RMSNorm":{ + "DeepseekV3RMSNorm": { "forward": QEFFGrok1CustomRMSNormAIC.forward, - } + }, } diff --git a/examples/compare.py b/examples/compare.py deleted file mode 100644 index 5113670559..0000000000 --- a/examples/compare.py +++ /dev/null @@ -1,94 +0,0 @@ -import numpy as np -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig - -from QEfficient import QEFFAutoModelForCausalLM - -prompt = "Once upon a time," - -model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", 
torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) -tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -PREFILL_SEQ_LEN=128 -CTX_LEN = 128 -generation_len = 5 -generated_ids = [] - -inputs = tokenizer(prompt, return_tensors="pt", padding=True) -padded_len = inputs["input_ids"].shape[1] -num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float -padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len - -with torch.no_grad(): - out = model(**inputs) - predictions = torch.argmax(out.logits, dim=-1) - - -qeff_model_no_mla = QEFFAutoModelForCausalLM(model) - -qeff_model_mla = QEFFAutoModelForCausalLM(model) - -inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -inputs.pop("token_type_ids", None) -inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -past_key_values = [] -compressed_kvs = [] -for i in range(model.config.num_hidden_layers): - cache_len = 128 - pad_shape_k = (1, 64, cache_len, 192) - pad_shape_v = (1, 64, cache_len, 128) - past_key = torch.zeros((pad_shape_k), dtype=torch.float32) - past_value = torch.zeros((pad_shape_v), dtype=torch.float32) - pkv = (past_key, past_value) - past_key_values.append(pkv) - compressed_kvs.append(torch.zeros(1, cache_len, 576)) -inputs['compressed_kvs'] = compressed_kvs - -prefill_qeff_out_mla = qeff_model_mla.model(**inputs) - -inputs.pop("compressed_kvs") -inputs["past_key_values"] = past_key_values -prefill_qeff_out_no_mla = qeff_model_no_mla.model(**inputs) -breakpoint() -assert (prefill_qeff_out_mla.logits - out.logits[:, -1, :]).abs().max() < 1e-4 -assert (prefill_qeff_out_no_mla.logits - out.logits[:, -1, :]).abs().max() < 1e-4 - -position_ids = inputs["position_ids"] -qeff_out_mla = prefill_qeff_out_mla -qeff_out_no_mla = prefill_qeff_out_no_mla -qeff_mla_generated_ids = [] 
-qeff_no_mla_generated_ids = [] -for _ in range(1, generation_len): - next_token_id_mla = qeff_out_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1) - next_token_id_no_mla = qeff_out_no_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1) - qeff_mla_generated_ids.append(next_token_id_mla) - qeff_no_mla_generated_ids.append(next_token_id_no_mla) - position_ids = position_ids.max(1, keepdim=True).values + 1 - decode_inputs = { - "input_ids": next_token_id, - "position_ids": position_ids, - "compressed_kvs": qeff_out_mla["past_key_values"], - } - qeff_out_mla = qeff_model_mla.model(**decode_inputs) - - decode_inputs = { - "input_ids": next_token_id, - "position_ids": position_ids, - "past_key_values": qeff_out_no_mla["past_key_values"], - } - qeff_out_no_mla = qeff_model_no_mla.model(**decode_inputs) - breakpoint() - -qeff_mla_generated_ids.append(qeff_out_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) -qeff_mla_generated_ids = np.concatenate(qeff_mla_generated_ids, axis=1) -predicted_string = tokenizer.batch_decode(qeff_mla_generated_ids, skip_special_tokens=True) -print("QEFF Transformed Model Outputs (Torch CPU): \n") -print("Prompt:", repr(prompt)) -print("Completion:", repr(predicted_string)) - -qeff_no_mla_generated_ids.append(qeff_out_no_mla["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) -qeff_no_mla_generated_ids = np.concatenate(qeff_no_mla_generated_ids, axis=1) -predicted_string = tokenizer.batch_decode(qeff_no_mla_generated_ids, skip_special_tokens=True) -print("QEFF Transformed Model Outputs (Torch CPU): \n") -print("Prompt:", repr(prompt)) -print("Completion:", repr(predicted_string)) diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py index fbc8c106e7..fb8a724696 100644 --- a/examples/export_kimik2.py +++ b/examples/export_kimik2.py @@ -1,23 +1,39 @@ import torch -import torch + torch.set_printoptions( precision=4, edgeitems=2, - threshold=50, # max number of elements printed - linewidth=120 + threshold=50, # max number of elements 
printed + linewidth=120, ) from transformers import AutoModelForCausalLM, AutoTokenizer + from QEfficient import QEFFAutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", + torch_dtype=torch.float32, + num_hidden_layers=2, + trust_remote_code=True, +) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) qeff_model = QEFFAutoModelForCausalLM(model) -onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable":True, "online": False}) +onnx_path = qeff_model.export( + prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": False} +) print(onnx_path) -qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, mla_absorption_config={"enable":True, "online": False}, - mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=4, num_cores=16) +qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=128, + enable_mla=True, + mla_absorption_config={"enable": True, "online": False}, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=4, + num_cores=16, +) print(qpc_path) prompts = "Once upon a time," diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9c..503efc12dc 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint 
import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index b4a5f3d8e0..0df64cceba 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -1,14 +1,19 @@ import numpy as np import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers import AutoModelForCausalLM, AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -model = AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True +) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -PREFILL_SEQ_LEN=32 + + +PREFILL_SEQ_LEN = 32 CTX_LEN = 128 generation_len = 10 generated_ids = [] @@ -23,8 +28,7 @@ predictions = torch.argmax(out.logits, dim=-1) qeff_model = 
QEFFAutoModelForCausalLM(model) -qeff_model.mla(enable_mla=True, mla_absorption_config={"enable":False, "online": False}) - +qeff_model.mla(enable_mla=True, mla_absorption_config={"enable": True, "online": True}) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -45,7 +49,7 @@ past_value = torch.zeros((pad_shape_v), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) - + ckv = torch.zeros((pad_shape_ckv), dtype=torch.float32) k_pe = torch.zeros((pad_shape_k_pe), dtype=torch.float32) x = (ckv, k_pe) @@ -80,9 +84,18 @@ print("Completion:", repr(predicted_string)) -onnx_path = qeff_model.export(prefill_seq_len=1, enable_mla=True)#, mla_absorption_config={"enable":True, "online": False}) -qpc_path = qeff_model.compile(prefill_seq_len=1, ctx_len=128, enable_mla=True, #mla_absorption_config={"enable":True, "online": False}, -mxfp6_matmul=False, mxint8_kv_cache=False, num_devices=2, num_cores=16) +onnx_path = qeff_model.export( + prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": True} +) +qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=1024, + enable_mla=True, + mla_absorption_config={"enable": True, "online": True}, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=1, + num_cores=16, +) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) - diff --git a/examples/run_orig_kimi_k2.py b/examples/run_orig_kimi_k2.py index 695377ca06..558329fbfb 100644 --- a/examples/run_orig_kimi_k2.py +++ b/examples/run_orig_kimi_k2.py @@ -1,10 +1,12 @@ -import numpy as np import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers import AutoModelForCausalLM, AutoTokenizer -from QEfficient import QEFFAutoModelForCausalLM - -model = 
AutoModelForCausalLM.from_pretrained("/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", + torch_dtype=torch.float32, + num_hidden_layers=2, + trust_remote_code=True, +) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) prompt = "Once upon a time," From 0517d936c5aad09535a8cf6f23fd1ce74bfee042 Mon Sep 17 00:00:00 2001 From: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:56:55 +0530 Subject: [PATCH 09/51] Split kv_a_proj_with_mqa weights to get ckv and k_pe Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek_qeff.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 9e924f4253..0c44070b70 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -237,9 +237,9 @@ def __qeff_init__( fusedqk = torch.bmm(per_head_q_up, per_head_k_up) # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - - # self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias) - # self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias) + kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + self.kv_a_proj_with_mqa_ckv = 
torch.nn.Parameter(kv_a_proj_with_mqa_ckv.detach().clone()) + self.kv_a_proj_with_mqa_k_pe = torch.nn.Parameter(kv_a_proj_with_mqa_k_pe.detach().clone()) def fused_forward( self, @@ -258,10 +258,8 @@ def fused_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - # compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states) - # k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + compressed_kv = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_ckv) + k_pe = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_k_pe) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) From 1398388c9ff061b22e90ab8180756f8cb14d40a8 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:12:47 +0000 Subject: [PATCH 10/51] update min_masked_attention_value, head blocking, attn in fused_forward, prefillonly transform Signed-off-by: Mamta Singh --- .../deepseek_v3/modeling_deepseek_qeff.py | 87 ++++++++++++------- .../transformers/models/modeling_auto.py | 12 +++ .../transformers/models/pytorch_transforms.py | 32 +++++-- 3 files changed, 89 insertions(+), 42 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 0c44070b70..f5b67460e5 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,6 +1,7 @@ import math from typing import Dict, List, Optional, Tuple, Union +import os import torch import torch.nn.functional as F from torch import nn @@ -25,6 +26,7 @@ # logger, # ) from QEfficient.transformers.modeling_attn_mask_utils import 
_create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE def rotate_half(x): @@ -234,7 +236,14 @@ def __qeff_init__( # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) - fusedqk = torch.bmm(per_head_q_up, per_head_k_up) + + out = torch.matmul(per_head_q_up[0,:,:], per_head_k_up[0,:,:]) + for i in range(1, self.num_heads): + x = torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:]) + out = torch.cat((out,x), 0) + fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank) + + #fusedqk = torch.bmm(per_head_q_up, per_head_k_up) # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -290,30 +299,39 @@ def fused_forward( else: enable_absorption = False - if enable_absorption: - if absorb_online: - print("online absorption") - atn = torch.matmul( - torch.matmul(q_a_proj_out.unsqueeze(1), torch.bmm(self.per_head_q_up, self.per_head_k_up)), - kva.transpose(1, 2).unsqueeze(1), - ) + x = [] + for i in range(self.num_heads): + if enable_absorption: + if absorb_online: + if i==0: + print("online absorption") + out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:]) + out = out.reshape(1, -1, self.kv_lora_rank) + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + if i==0: + print("using fused qk") + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:]) + + out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1) + kva_kpe = torch.cat((kva,k_pe.squeeze(1)), -1) + attn_weights = torch.matmul(out3, kva_kpe.transpose(1, 2).unsqueeze(1)) * self.softmax_scale else: - print("using fused qk") - atn = 
torch.matmul( - torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk), kva.transpose(1, 2).unsqueeze(1) - ) - else: - print("no absorption") - atn = torch.matmul(q_nope, k_nope.transpose(2, 3)) + if i==0: + print("no absorption") + query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1) + key_states = torch.cat((k_nope[:,i,:,:].unsqueeze(1), k_pe), -1) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - atr = torch.matmul(q_pe, k_pe.expand(-1, self.num_heads, -1, -1).transpose(2, 3)) - attn_weights = (atn + atr) * self.softmax_scale + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, value_states[:,i,:,:]) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states) + x.append(attn_output) + + attn_output = torch.cat(x, dim=1) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -356,7 +374,7 @@ def forward( q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.repeat(1, self.num_heads, 1, 1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) key_states = torch.cat((k_nope, k_pe_new), -1) if past_key_value is not None: @@ -366,7 +384,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it - 
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) @@ -459,16 +477,18 @@ class QEffPrefillOnlyDeepseekV3MoE(nn.Module): def __qeff_init__( self, ): - self.all_gate_proj = torch.nn.Parameter( - torch.cat([exp.gate_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.all_up_proj = torch.nn.Parameter( - torch.cat([exp.up_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.all_down_proj = torch.nn.Parameter( - torch.cat([exp.down_proj.weight.T.unsqueeze(0) for exp in self.experts], dim=0) - ) - self.act_fn = self.experts[0].act_fn + for exp in self.experts: + gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + + gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) + up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) + down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) + + setattr(exp,"gate_proj", gate_proj) + setattr(exp,"up_proj", up_proj) + setattr(exp,"down_proj", down_proj) def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) @@ -481,6 +501,7 @@ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_ma current_hidden_states = expert_output * 
expert_mask[:, expert_idx].unsqueeze(-1) final_hidden_states += current_hidden_states + print("\n\ninside prefill only moe\n") return final_hidden_states.type(hidden_states.dtype) def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 257d50009e..d99ff2ad28 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,10 +51,12 @@ KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyExternalModuleMapperTransform, PrefillOnlyChunkedTransform, PrefillOnlyTransform, RevertPrefillKeepAttentionTransform, RevertPrefillOnlyTransform, + RevertPrefillOnlyExternalModuleMapperTransform, SamplerTransform, SpDTransform, TextClassificationTransform, @@ -2700,12 +2702,14 @@ def prefill( retain_full_kv: Optional[bool] = False, ): if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) if enable_chunking: self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) else: self.model, tf = PrefillOnlyTransform.apply(self.model) else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) if retain_full_kv: self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) else: @@ -3013,6 +3017,14 @@ def export( self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) + if self.model.config.model_type in {"kimi_k2"}: + if prefill_only: + self.prefill(enable=True) + self.hash_params["prefill_only"] = True + else: + self.prefill(enable=False) + self.hash_params.pop("prefill_only", None) + # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: diff --git 
a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index cf9350cdfc..71e094c28e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -739,7 +739,6 @@ class PrefillOnlyTransform(ModuleMappingTransform): QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, - QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE, } @@ -749,14 +748,10 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, -<<<<<<< HEAD # Qwen3Moe QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, # Qwen3 VL Moe QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, -======= - QEffDeepseekV3MoE: QEffPrefillOnlyDeepseekV3MoE, ->>>>>>> ba3218c (Add prefill only moe changes from kimik2 branch) } @@ -768,12 +763,8 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, -<<<<<<< HEAD # Qwen3Moe QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, -======= - QEffPrefillOnlyDeepseekV3MoE: QEffDeepseekV3MoE, ->>>>>>> ba3218c (Add prefill only moe changes from kimik2 branch) } @@ -982,6 +973,29 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, } +class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_class_replace_method = {} + _match_string_replace_method = { + "DeepseekV3MoE": { + "forward": QEffPrefillOnlyDeepseekV3MoE.forward, + "moe": QEffPrefillOnlyDeepseekV3MoE.moe, + "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + }, + } + +class 
RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_class_replace_method = {} + _match_string_replace_method = { + "DeepseekV3MoE": { + "forward": QEffDeepseekV3MoE.forward, + "moe": QEffDeepseekV3MoE.moe, + "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, + }, + } + '''_match_string_replace_method = { + **{v: k for k, v in PrefillOnlyExternalModuleMapperTransform._match_string_replace_method.items()}, + } + ''' class T5ModelTransform(ModuleMappingTransform): # supported architectures From 8aaa56a07e822bee872852f0e27fef8c01387f5c Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:16:11 +0000 Subject: [PATCH 11/51] Add replicatekvhead transform Signed-off-by: Mamta Singh --- QEfficient/transformers/cache_utils.py | 6 +- .../deepseek_v3/modeling_deepseek_qeff.py | 121 ++++++++---------- .../transformers/models/modeling_auto.py | 12 +- .../transformers/models/pytorch_transforms.py | 75 +++++++++++ examples/run_kimik2.py | 40 +++--- 5 files changed, 164 insertions(+), 90 deletions(-) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index e3d88f706f..97e4c56c02 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -352,10 +352,10 @@ def update_ckv(self, compressed_kv, cache_kwargs): position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later - self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv) + self.ckv = CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv) ckv_out = self.ckv - ctx_len = ckv_out.shape[1] + ctx_len = ckv_out.shape[-2] ctx_indices = torch.arange(ctx_len)[None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values invalid_mask = ctx_indices > gather_limit @@ -365,7 +365,7 @@ def update_ckv(self, compressed_kv, cache_kwargs): invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices) + ckv_out = CtxGatherFunc.apply(ckv_out, ctx_indices, ctx_len) ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) return ckv_out diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index f5b67460e5..ad470f1c62 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -214,26 +214,23 @@ def __qeff_init__( -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - # self.register_buffer("q_up", q_up.detach().clone(), persistent=False) + self.q_up = torch.nn.Parameter(q_up.detach().clone()) q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) - # self.register_buffer("q_rope", q_rope.detach().clone(), persistent=False) + self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) - # self.register_buffer("k_up", k_up.detach().clone(), persistent=False) - # self.register_buffer("v_up", v_up.detach().clone(), persistent=False) + self.k_up = torch.nn.Parameter(k_up.detach().clone()) self.v_up = torch.nn.Parameter(v_up.detach().clone()) 
per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) per_head_k_up = ( self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) ) - # self.register_buffer("per_head_q_up", per_head_q_up.detach().clone(), persistent=False) - # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) @@ -243,12 +240,7 @@ def __qeff_init__( out = torch.cat((out,x), 0) fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank) - #fusedqk = torch.bmm(per_head_q_up, per_head_k_up) - # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - self.kv_a_proj_with_mqa_ckv = torch.nn.Parameter(kv_a_proj_with_mqa_ckv.detach().clone()) - self.kv_a_proj_with_mqa_k_pe = torch.nn.Parameter(kv_a_proj_with_mqa_k_pe.detach().clone()) def fused_forward( self, @@ -267,9 +259,9 @@ def fused_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - compressed_kv = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_ckv) - k_pe = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_k_pe) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) + compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.bmm(q_a_proj_out, self.q_rope) @@ -277,21 +269,12 @@ def 
fused_forward( q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) - kva = self.kv_a_layernorm(compressed_kv) - k_nope = torch.bmm(kva, self.k_up) - k_nope = k_nope.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - value_states = torch.bmm(kva, self.v_up) - value_states = value_states.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + kva = compressed_kv if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -299,37 +282,56 @@ def fused_forward( else: enable_absorption = False + n_head_ckv = compressed_kv.shape[1] + p = self.num_heads//n_head_ckv + + value_out = [] + for i in range(n_head_ckv): + value_states_ph = torch.matmul(kva[:,i,:,:], self.v_up[:, :, i*p*self.v_head_dim: (i+1)*p*self.v_head_dim]) + value_states_ph = value_states_ph.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) + value_out.append(value_states_ph) + value_states = torch.cat(value_out, dim=1) + + cos, sin = self.rotary_emb(value_states_ph, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + x = [] - for i in range(self.num_heads): - if enable_absorption: - if absorb_online: - if i==0: - print("online absorption") - out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:]) - out = 
out.reshape(1, -1, self.kv_lora_rank) - out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) + for k in range(n_head_ckv): + k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim]) + k_nope = k_nope.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) + + for i in range(k*p, (k+1)*p): + if enable_absorption: + if absorb_online: + if i==0: + print("online absorption") + out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:]) + out = out.reshape(1, -1, self.kv_lora_rank) + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + if i==0: + print("using fused qk") + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:]) + + out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1) + kva_kpe = torch.cat((kva[:,k,:,:],k_pe[:,k,:,:]), -1).unsqueeze(1) + attn_weights = torch.matmul(out3, kva_kpe.transpose(2,3)) * self.softmax_scale else: if i==0: - print("using fused qk") - out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:]) - - out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1) - kva_kpe = torch.cat((kva,k_pe.squeeze(1)), -1) - attn_weights = torch.matmul(out3, kva_kpe.transpose(1, 2).unsqueeze(1)) * self.softmax_scale - else: - if i==0: - print("no absorption") - query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1) - key_states = torch.cat((k_nope[:,i,:,:].unsqueeze(1), k_pe), -1) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + print("no absorption") + query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1).unsqueeze(1) + key_states = torch.cat((k_nope[:,i%p,:,:], k_pe[:,k,:,:]), -1).unsqueeze(1) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + if 
attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states[:,i,:,:]) - - x.append(attn_output) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, value_states[:,i,:,:]) + x.append(attn_output) attn_output = torch.cat(x, dim=1) @@ -455,23 +457,6 @@ def forward(self, hidden_states): hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states - # def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - # final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - # expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) - # expert_mask = expert_mask.permute(2, 0, 1) - - # for expert_idx in range(len(self.experts)): - # expert = self.experts[expert_idx] - # mask = expert_mask[expert_idx] - # expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None]) - # expert_output = torch.where( - # (topk_weights * mask).sum(1).to(torch.bool)[:, None], - # expert_output, - # torch.tensor(0.0), - # ) - # final_hidden_states = final_hidden_states + expert_output - # return final_hidden_states.type(hidden_states.dtype) - class QEffPrefillOnlyDeepseekV3MoE(nn.Module): def __qeff_init__( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d99ff2ad28..ba0d6f4697 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -54,6 +54,7 @@ PrefillOnlyExternalModuleMapperTransform, PrefillOnlyChunkedTransform, PrefillOnlyTransform, + ReplicateKVHeadTransform, 
RevertPrefillKeepAttentionTransform, RevertPrefillOnlyTransform, RevertPrefillOnlyExternalModuleMapperTransform, @@ -2804,6 +2805,11 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached + if self.model.config.model_type in {"kimi_k2"}: + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + if replicate_kv_transformed: + self.hash_params["config"] = model.config.to_diff_dict() + # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the @@ -3150,11 +3156,11 @@ def export( output_names = [v for v in output_names if "past" not in v] example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - ckv = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros((bs, 1, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + ckv = torch.zeros((bs, 4, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) + k_pe = torch.zeros((bs, 4, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) example_inputs["compressed_kvs"][i].append(ckv) example_inputs["compressed_kvs"][i].append(k_pe) - dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "ctx_len"} + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} output_names.append(f"compressed_kv.{i}_RetainedState") output_names.append(f"k_pe.{i}_RetainedState") diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 71e094c28e..6246268779 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -10,6 +10,7 @@ from types import MethodType from typing import 
Callable, Optional, Tuple, Union +import torch from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -508,6 +509,7 @@ from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.logging_utils import logger SPD_TARGET = "target" @@ -775,6 +777,79 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): } +class ReplicateKVHeadTransform: + """ + Replicates KV heads in attention modules to match the number of KV heads in the target model. + This transform is used when the source model has fewer KV heads than required in target model. + """ + + def _duplicate_weights_for_linear_layer( + layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int + ): + new_kv_heads = repeat #for mla + + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 + ).view(new_kv_heads * dim, hidden_size) + + if layer.bias is not None: + layer.bias.data = torch.repeat_interleave( + layer.bias.data.view(orig_kv_heads, dim), repeat, 0 + ).view(new_kv_heads * dim) + + def _get_text_model(model): + """ + Determine and return the appropriate text_model from a given model object. + """ + # Check for VLMs + if hasattr(model, "language_model"): + if hasattr(model.language_model, "model"): + return model.language_model.model + else: + return model.language_model + # Check for CausalLMs + if hasattr(model, "model"): + return model.model + + raise AttributeError("No suitable text model found in the provided model.") + + @classmethod + def apply(cls, model: nn.Module, **kwargs) -> nn.Module: + """ + Replicates KV heads in attention modules based on provided multiplier. + + Args: + model: The model to apply the transform to. + kwargs: Additional arguments for the transformation. 
Includes: + - num_kv_heads_repeat: The number of times to repeat the KV heads. + """ + n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + transformed = False + if n_repeat is not None and n_repeat > 1: + text_model = cls._get_text_model(model) + + orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads + new_kv_heads = n_repeat*orig_kv_heads + text_model.config.orig_kv_heads = orig_kv_heads + text_model.config.num_key_value_heads = new_kv_heads + + num_attention_heads = text_model.config.num_attention_heads + hidden_size = text_model.config.hidden_size + + logger.warning(f"Original KV heads: {orig_kv_heads}") + logger.warning(f"Modified KV heads: {new_kv_heads}") + transformed = True + for block in text_model.layers: + attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) + attn.num_key_value_heads = new_kv_heads + head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim + + cls._duplicate_weights_for_linear_layer( + attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size + ) + return model, transformed + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. 
diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 0df64cceba..3f6841eefa 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -5,14 +5,18 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," +num_kv_heads_repeat=4 #TS=4 +num_hidden_layers=2 +enable_mla=True +mla_absorption_config={"enable": False, "online": False} -model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +#model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +model_path ="/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True + model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) - PREFILL_SEQ_LEN = 32 CTX_LEN = 128 generation_len = 10 @@ -27,8 +31,8 @@ out = model(**inputs) predictions = torch.argmax(out.logits, dim=-1) -qeff_model = QEFFAutoModelForCausalLM(model) -qeff_model.mla(enable_mla=True, mla_absorption_config={"enable": True, "online": True}) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat) +qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -38,8 +42,8 @@ cache_len = 128 pad_shape_k = (1, 64, cache_len, 192) pad_shape_v = (1, 64, cache_len, 128) -pad_shape_ckv = (1, cache_len, 512) -pad_shape_k_pe = (1, 1, cache_len, 64) +pad_shape_ckv = (1, num_kv_heads_repeat, cache_len, 512) 
+pad_shape_k_pe = (1, num_kv_heads_repeat, cache_len, 64) past_key_values = [] compressed_kvs = [] @@ -57,9 +61,11 @@ inputs["compressed_kvs"] = compressed_kvs +#inputs["past_key_values"] = past_key_values prefill_qeff_out = qeff_model.model(**inputs) +breakpoint() assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 position_ids = inputs["position_ids"] @@ -73,6 +79,7 @@ "input_ids": next_token_id, "position_ids": position_ids, "compressed_kvs": qeff_out["past_key_values"], + #"past_key_values": qeff_out["past_key_values"], } qeff_out = qeff_model.model(**decode_inputs) @@ -84,18 +91,19 @@ print("Completion:", repr(predicted_string)) -onnx_path = qeff_model.export( - prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": True} -) +prefill_seq_len = 1 +ctx_len = 1024 + qpc_path = qeff_model.compile( - prefill_seq_len=1, - ctx_len=1024, - enable_mla=True, - mla_absorption_config={"enable": True, "online": True}, - mxfp6_matmul=True, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + enable_mla=enable_mla, + mla_absorption_config=mla_absorption_config, + mxfp6_matmul=False, mxint8_kv_cache=False, - num_devices=1, + num_devices=num_kv_heads_repeat, num_cores=16, + #prefill_only=True, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From b05ca6ef704b7188fb4de1cc71e63700d6111881 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:23:58 +0000 Subject: [PATCH 12/51] Add KV blocking and update head blocking Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 67 ++ QEfficient/blocking/attention_blocking.py | 126 ++++ .../blocking/blocked_attention_forwards.py | 576 ++++++++++++++++++ QEfficient/blocking/blocking_configurator.py | 259 ++++++++ .../deepseek_v3/modeling_deepseek_qeff.py | 78 ++- .../transformers/models/modeling_auto.py | 59 +- .../transformers/models/pytorch_transforms.py | 55 +- QEfficient/utils/__init__.py | 2 + QEfficient/utils/_utils.py | 17 + 
QEfficient/utils/constants.py | 4 +- QEfficient/utils/export_utils.py | 1 + QEfficient/utils/hash_utils.py | 8 + examples/run_kimik2.py | 15 +- 13 files changed, 1194 insertions(+), 73 deletions(-) create mode 100644 QEfficient/blocking/attention_blocking.py create mode 100644 QEfficient/blocking/blocked_attention_forwards.py create mode 100644 QEfficient/blocking/blocking_configurator.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b6da4dcf37..a2eeef3b0a 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -26,15 +26,19 @@ ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile +from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.models.pytorch_transforms import BlockingAttentionTransform from QEfficient.utils import ( constants, create_json, create_model_params, dump_qconfig, generate_mdp_partition_config, + get_attr_or_key, hash_dict_params, load_json, + require_value, ) from QEfficient.utils.export_utils import export_wrapper @@ -365,6 +369,47 @@ def get_onnx_path( self.export(**kwargs) return self.onnx_path + + def transform( + self, + ctx_len: Optional[int] = None, + seq_len: Optional[int] = None, + bs: Optional[int] = 1, + num_devices: int = 1, + qaic_config: Optional[dict] = None, + **compiler_options, + ): + # Apply the transformations that are dependent on compilation parameters + + qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) + + if getattr(self.model, "config", None) or getattr(self.model.model, "config", None): + blocking_config = build_transformer_blocking_config_for_transform( + getattr(self.model, "config", None) + if getattr(self.model, "config", None) + else getattr(self.model.model, "config", None), + 
ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + else: + # without a model config, this is not a model that is possible to block + blocking_config = None + + if blocking_config is not None: + self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) + blocking_kwargs = self.hash_params.setdefault("blocking_kwargs", {}) + if blocking_config.num_kv_blocks: + blocking_kwargs["num_kv_blocks"] = blocking_config.num_kv_blocks + if blocking_config.num_q_blocks: + blocking_kwargs["num_q_blocks"] = blocking_config.num_q_blocks + if blocking_config.head_block_size: + blocking_kwargs["head_block_size"] = blocking_config.head_block_size + + @dump_qconfig def _compile( self, @@ -383,6 +428,10 @@ def _compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, + disable_blocking: Optional[bool] = True, + blocking_mode: Optional[str] = "hqkv", + vtcm_ratio: Optional[float] = 0.75, + qaic_config: Optional[dict] = None, enable_mla: Optional[bool] = False, mla_absorption_config: Optional[Dict[str, bool]] = False, **compiler_options, @@ -410,6 +459,24 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" + + # Transform before export + qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) + bs = require_value(get_attr_or_key(specializations[0], ("batch_size", "batch")), "batch size") + seq_len = get_attr_or_key(specializations[0], ("cl", "seq_len", "sequence_length")) + ctx_len = get_attr_or_key(specializations[0], ("ctx_len", "context_length")) + self.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=mdp_ts_num_devices, + disable_blocking=disable_blocking, + blocking_mode=blocking_mode, + vtcm_ratio=vtcm_ratio, + qaic_config=qaic_config, + **compiler_options, + ) + onnx_path = Path( onnx_path if onnx_path diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py new file mode 100644 index 0000000000..da530d6c8b --- /dev/null +++ b/QEfficient/blocking/attention_blocking.py @@ -0,0 +1,126 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional + +import torch +from transformers.cache_utils import Cache + +from QEfficient.blocking.blocked_attention_forwards import ( + blocked_h_attention_forward, + blocked_hqkv_attention_forward, + blocked_kv_attention_forward, + blocked_q_attention_forward, + blocked_qkv_attention_forward, + invalid_blocking_attention_forward, +) + + +class BlockingMode(str, Enum): + NONE = "" + KV = "kv" + Q = "q" + H = "h" + QKV = "qkv" + HQKV = "hqkv" + + +@dataclass +class AttentionBlockingConfig: + mode: BlockingMode = BlockingMode.NONE + num_kv_blocks: Optional[int] = None + num_q_blocks: Optional[int] = None + head_block_size: Optional[int] = None + + +def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: + return past_key_value is not None and hasattr(past_key_value, "read_only_blockedKV") + + +_STRATEGIES: Dict[BlockingMode, Callable] = { + BlockingMode.NONE: invalid_blocking_attention_forward, + BlockingMode.KV: blocked_kv_attention_forward, + BlockingMode.Q: blocked_q_attention_forward, + BlockingMode.H: blocked_h_attention_forward, + BlockingMode.QKV: blocked_qkv_attention_forward, + BlockingMode.HQKV: blocked_hqkv_attention_forward, +} + + +def get_blocking_strategy(config: AttentionBlockingConfig) -> Callable: + return _STRATEGIES.get(config.mode, _STRATEGIES[BlockingMode.NONE]) + + +def generic_blocked_attention_interface( + module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + layer_idx: int, + past_key_value: Cache, + blocking_config: AttentionBlockingConfig, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + 
past_seen_tokens: Optional[int] = None, + non_blocked_forward: Callable = None, + **kwargs, +): + use_kv_blocked = ( + blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value) + ) + use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) + + if past_key_value is not None: + if use_kv_blocked: + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + "past_seen_tokens": past_seen_tokens, + } + past_key_value.write_only(key, value, module.layer_idx, cache_kwargs) + else: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key, value = past_key_value.update(key, value, module.layer_idx, cache_kwargs) + + if use_blocking: + strategy = get_blocking_strategy(blocking_config) + attn_output, attn_weights = strategy( + module=module, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + scaling=scaling, + cache_kwargs=cache_kwargs, + layer_idx=layer_idx, + past_key_value=past_key_value, + num_kv_blocks=blocking_config.num_kv_blocks, + num_q_blocks=blocking_config.num_q_blocks, + head_block_size=blocking_config.head_block_size, + ) + else: + attn_output, attn_weights = non_blocked_forward( + module, + query, + key, + value, + attention_mask, + scaling=scaling, + **kwargs, + ) + + return attn_output, attn_weights diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py new file mode 100644 index 0000000000..37ae942f52 --- /dev/null +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -0,0 +1,576 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep) for GQA.

    The input is unpacked as (batch, num_key_value_heads, slen, head_dim); the
    result has ``num_key_value_heads * n_rep`` entries on dim 1. With
    ``n_rep == 1`` the input is returned unchanged.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _get_kv_states(module: nn.Module, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat key/value heads by ``module.num_key_value_groups``; pass through if the attribute is absent."""
    num_kv_groups = getattr(module, "num_key_value_groups", None)
    if num_kv_groups is None:
        return key, value
    return repeat_kv(key, num_kv_groups), repeat_kv(value, num_kv_groups)


def _normalize_int(value: "Optional[torch.Tensor | int]") -> int:
    """Convert a tensor or int to a Python int; None becomes 0."""
    if isinstance(value, torch.Tensor):
        return int(value.item())
    return int(value) if value is not None else 0


def _block_mask(
    attention_mask: Optional[torch.Tensor],
    kv_width: int,
    start_index: int,
    end_index: int,
    use_causal_mask: bool,
    position_ids,
    target_length,
    sliding_window: Optional[int],
):
    """Build the boolean mask for one KV block.

    The caller's mask is sliced to ``[start_index:end_index]`` on its last dim and
    dropped if the slice width differs from ``kv_width``. A causal mask from
    ``_create_causal_mask`` is used when no usable slice exists, or OR-ed with the
    slice when ``use_causal_mask`` is set.
    """
    mask_block = None
    if attention_mask is not None:
        mask_block = attention_mask[..., start_index:end_index]
        if mask_block.shape[-1] != kv_width:
            mask_block = None

    if use_causal_mask or mask_block is None:
        causal_mask_block = _create_causal_mask(
            position_ids=position_ids,
            target_length=target_length,
            sliding_window=sliding_window,
            start_index=start_index,
        )
        if mask_block is None:
            mask_block = causal_mask_block
        else:
            mask_block = mask_block.to(torch.bool) | causal_mask_block
    return mask_block


def _online_softmax_update(
    scores: torch.Tensor,
    values: torch.Tensor,
    prev_max: torch.Tensor,
    prev_denominator: torch.Tensor,
    prev_output: torch.Tensor,
    skip_future: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fold one KV block into the running (row max, denominator, output) softmax state.

    While exporting to ONNX or tracing, the previous state is kept wherever
    ``skip_future`` is true (the loop cannot be cut short there); in eager mode
    the updated state is always returned.
    """
    # Update running row maximum
    new_max = torch.max(prev_max, scores.max(dim=-1).values)
    delta_max = prev_max - new_max

    block_exp = torch.exp(scores - new_max.unsqueeze(-1))

    # Update running denominator: rescale the old sum to the new maximum first.
    block_exp_sum = torch.einsum("bhqk->bhq", block_exp)
    new_denominator = prev_denominator * torch.exp(delta_max) + block_exp_sum

    prob = block_exp / new_denominator.unsqueeze(-1)

    new_output = ((prev_denominator / new_denominator).unsqueeze(-1)) * prev_output * torch.exp(
        delta_max.unsqueeze(-1)
    ) + torch.matmul(prob, values)

    if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
        # skip_future is 0-dim and broadcasts against every state tensor.
        new_max = torch.where(skip_future, prev_max, new_max)
        new_denominator = torch.where(skip_future, prev_denominator, new_denominator)
        new_output = torch.where(skip_future, prev_output, new_output)
    return new_max, new_denominator, new_output


def blocked_kv_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    past_key_value: "Cache",
    *,
    score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """KV-blocked attention: read the cache in ``num_kv_blocks`` slices and merge them with an online softmax.

    ``key``/``value`` are not used here; K/V come from
    ``past_key_value.read_only_blockedKV``. ``cache_kwargs`` must provide
    ``"position_ids"`` and ``"past_seen_tokens"``. ``score_mod``, when given, is
    applied to each block's scaled scores as ``score_mod(scores, start, end)``.
    During ONNX export the supplied mask is ignored and a causal mask is built.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` is the accumulated output
        with dims 1 and 2 transposed.
    """
    # Initialize result tensor
    output = torch.zeros_like(query)

    # Initialize running maximum and denominator
    batch_size, num_heads, seq_len, _ = query.shape
    current_max = torch.full(
        (batch_size, num_heads, seq_len),
        float(MIN_MASKED_ATTENTION_VALUE),
        device=query.device,
    )
    current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device)

    past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens"))
    total_seen_tokens = past_seen_tokens + query.shape[2]
    if torch.onnx.is_in_onnx_export():
        attention_mask = None
        use_causal_mask = True
    position_ids = cache_kwargs.get("position_ids")
    # Clamp like num_q_blocks elsewhere so None/0 means "one block" instead of "no blocks".
    num_kv_blocks = max(1, _normalize_int(num_kv_blocks))
    kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)]
    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device)

    current_position = position_ids.max(dim=-1).values

    for j in range(num_kv_blocks):
        start_index = kv_block_positions[j]
        if j == num_kv_blocks - 1:
            kv_len_block = past_seen_tokens - start_index
        else:
            kv_len_block = kv_block_positions[j + 1] - start_index
        end_index = start_index + kv_len_block

        # True when this block starts after every current position.
        skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()

        k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
        k_block_states, v_block_states = _get_kv_states(module, k_block, v_block)

        attn_weights_block = torch.matmul(query, k_block_states.transpose(2, 3)) * scaling
        if score_mod is not None:
            attn_weights_block = score_mod(attn_weights_block, start_index, end_index)

        mask_block = _block_mask(
            attention_mask,
            attn_weights_block.shape[-1],
            start_index,
            end_index,
            use_causal_mask,
            position_ids,
            min(total_seen_tokens, end_index),
            sliding_window,
        )
        if mask_block is not None:
            attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block)

        current_max, current_denominator, output = _online_softmax_update(
            attn_weights_block, v_block_states, current_max, current_denominator, output, skip_future
        )

    attn_output = output.transpose(1, 2).contiguous()
    attn_weights = None

    return attn_output, attn_weights


def blocked_qkv_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    num_q_blocks: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    past_key_value: "Cache",
    *,
    score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Q- and KV-blocked attention: each query block sweeps the cached KV blocks with an online softmax.

    Same cache contract as ``blocked_kv_attention_forward``. In eager mode the KV
    sweep stops at the first block that starts after every current position.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_weights`` concatenates, over the
        query blocks, the masked scores of the last KV block each one processed
        (not normalised probabilities); it is None if no KV block was processed.
    """
    batch_size, num_heads, seq_len, DH = query.shape

    past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens"))
    if torch.onnx.is_in_onnx_export():
        attention_mask = None
        use_causal_mask = True
    position_ids = cache_kwargs.get("position_ids")
    num_kv_blocks = max(1, _normalize_int(num_kv_blocks))
    num_q_blocks = max(1, _normalize_int(num_q_blocks))

    q_block_positions = [(i * seq_len) // num_q_blocks for i in range(num_q_blocks)]
    q_output_blocks = []
    q_attn_blocks = []

    kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)]

    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device)

    current_position = position_ids.max(dim=-1).values

    for q_block_idx in range(num_q_blocks):
        q_start = q_block_positions[q_block_idx]
        if q_block_idx == num_q_blocks - 1:
            q_len_block = seq_len - q_start
        else:
            q_len_block = q_block_positions[q_block_idx + 1] - q_start

        q_block = query[:, :, q_start : q_start + q_len_block, :]

        current_max = torch.full(
            (batch_size, num_heads, q_len_block),
            float(MIN_MASKED_ATTENTION_VALUE),
            device=query.device,
        )
        current_denominator = torch.zeros(batch_size, num_heads, q_len_block, device=query.device)
        output_blocks = torch.zeros((batch_size, num_heads, q_len_block, DH), device=query.device, dtype=query.dtype)
        # Stays None when the sweep below never processes a block.
        last_scores = None

        for j in range(num_kv_blocks):
            start_index = kv_block_positions[j]
            if j == num_kv_blocks - 1:
                kv_len_block = past_seen_tokens - start_index
            else:
                kv_len_block = kv_block_positions[j + 1] - start_index
            end_index = start_index + kv_len_block

            skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()

            # Eager mode only: data-dependent early exit is not traceable.
            if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                if skip_future.item():
                    break

            k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
            k_block_states, v_block_states = _get_kv_states(module, k_block, v_block)

            attn_weights_block = torch.matmul(q_block, k_block_states.transpose(2, 3)) * scaling
            if score_mod is not None:
                attn_weights_block = score_mod(attn_weights_block, start_index, end_index)

            # Tensor form of min(past_seen_tokens, end_index).
            target_length = torch.where(
                torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
                past_seen_tokens,
                end_index,
            )
            mask_block = _block_mask(
                attention_mask,
                attn_weights_block.shape[-1],
                start_index,
                end_index,
                use_causal_mask,
                position_ids,
                target_length,
                sliding_window,
            )
            if mask_block is not None:
                attn_mask_block = mask_block[:, :, q_start : q_start + q_len_block, :]
                attn_weights_block = torch.where(attn_mask_block, masked_tensor, attn_weights_block)

            current_max, current_denominator, output_blocks = _online_softmax_update(
                attn_weights_block, v_block_states, current_max, current_denominator, output_blocks, skip_future
            )
            last_scores = attn_weights_block

        q_output_blocks.append(output_blocks)
        if last_scores is not None:
            q_attn_blocks.append(last_scores)

    attn_output = torch.cat(q_output_blocks, dim=2).transpose(1, 2).contiguous()
    attn_weights = torch.cat(q_attn_blocks, dim=2) if q_attn_blocks else None

    return attn_output, attn_weights


def blocked_hqkv_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    num_q_blocks: int,
    head_block_size: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    past_key_value: "Cache",
    *,
    score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Head-, Q- and KV-blocked attention.

    Heads are processed in groups of ``head_block_size`` (None or <= 0 means all
    heads at once); within a group the computation matches
    ``blocked_qkv_attention_forward``.

    Returns:
        ``(attn_output, attn_weights)`` with the per-group results concatenated on
        the head dim; ``attn_weights`` is None if no KV block was processed.
    """
    batch_size, num_heads, seq_len, DH = query.shape

    past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens"))
    if torch.onnx.is_in_onnx_export():
        attention_mask = None
        use_causal_mask = True
    position_ids = cache_kwargs.get("position_ids")
    num_kv_blocks = max(1, _normalize_int(num_kv_blocks))
    # The config default is None; normalise before comparing.
    head_block_size = _normalize_int(head_block_size)
    if head_block_size <= 0:
        head_block_size = num_heads
    num_head_blocks = math.ceil(num_heads / head_block_size)
    num_q_blocks = max(1, _normalize_int(num_q_blocks))

    q_block_positions = [(i * seq_len) // num_q_blocks for i in range(num_q_blocks)]

    h_output_blocks = []
    h_attn_blocks = []

    kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)]

    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device)

    current_position = position_ids.max(dim=-1).values

    # Process each head block independently
    for head_block_idx in range(num_head_blocks):
        h_start = head_block_idx * head_block_size
        h_end = min(h_start + head_block_size, num_heads)

        # Extract head blocks
        q_g = query[:, h_start:h_end, :, :]

        q_output_blocks = []
        q_attn_blocks = []

        for q_block_idx in range(num_q_blocks):
            q_start = q_block_positions[q_block_idx]
            if q_block_idx == num_q_blocks - 1:
                q_len_block = seq_len - q_start
            else:
                q_len_block = q_block_positions[q_block_idx + 1] - q_start

            q_block = q_g[:, :, q_start : q_start + q_len_block, :]

            current_max = torch.full(
                (batch_size, h_end - h_start, q_len_block),
                float(MIN_MASKED_ATTENTION_VALUE),
                device=query.device,
            )
            current_denominator = torch.zeros(batch_size, h_end - h_start, q_len_block, device=query.device)
            output_blocks = torch.zeros(
                (batch_size, h_end - h_start, q_len_block, DH), device=query.device, dtype=query.dtype
            )
            # Stays None when the sweep below never processes a block.
            last_scores = None

            for j in range(num_kv_blocks):
                start_index = kv_block_positions[j]
                if j == num_kv_blocks - 1:
                    kv_len_block = past_seen_tokens - start_index
                else:
                    kv_len_block = kv_block_positions[j + 1] - start_index
                end_index = start_index + kv_len_block

                skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()

                # Eager mode only: data-dependent early exit is not traceable.
                if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                    if skip_future.item():
                        break

                k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
                k_block_states, v_block_states = _get_kv_states(module, k_block, v_block)

                k_g = k_block_states[:, h_start:h_end, :, :]
                v_g = v_block_states[:, h_start:h_end, :, :]

                attn_weights_block = torch.matmul(q_block, k_g.transpose(2, 3)) * scaling
                if score_mod is not None:
                    attn_weights_block = score_mod(attn_weights_block, start_index, end_index)

                # Tensor form of min(past_seen_tokens, end_index).
                target_length = torch.where(
                    torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
                    past_seen_tokens,
                    end_index,
                )
                mask_block = _block_mask(
                    attention_mask,
                    attn_weights_block.shape[-1],
                    start_index,
                    end_index,
                    use_causal_mask,
                    position_ids,
                    target_length,
                    sliding_window,
                )
                if mask_block is not None:
                    mask_block_g = mask_block[:, :, q_start : q_start + q_len_block, :]
                    attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block)

                current_max, current_denominator, output_blocks = _online_softmax_update(
                    attn_weights_block, v_g, current_max, current_denominator, output_blocks, skip_future
                )
                last_scores = attn_weights_block

            q_output_blocks.append(output_blocks)
            if last_scores is not None:
                q_attn_blocks.append(last_scores)

        h_output_blocks.append(torch.cat(q_output_blocks, dim=2))
        if q_attn_blocks:
            h_attn_blocks.append(torch.cat(q_attn_blocks, dim=2))

    attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous()
    attn_weights = torch.cat(h_attn_blocks, dim=1) if h_attn_blocks else None

    return attn_output, attn_weights


def blocked_h_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    head_block_size: int,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Head-blocked attention that slices the heads into groups of ``head_block_size``
    and runs a regular softmax attention on each group.

    ``head_block_size`` of None or <= 0 processes all heads in one group.
    ``attention_mask`` marks masked positions with True. Returns
    ``(attn_output, attn_weights)``: the group outputs concatenated on the head
    dim (then dims 1 and 2 transposed) and the softmax probabilities.
    """
    num_heads = query.shape[1]
    # The config default is None; normalise before comparing.
    head_block_size = _normalize_int(head_block_size)
    if head_block_size <= 0:
        head_block_size = num_heads
    num_head_blocks = math.ceil(num_heads / head_block_size)

    key_states, value_states = _get_kv_states(module, key, value)

    h_output_blocks = []
    h_attn_blocks = []

    # Only needed when a mask is applied.
    masked_tensor = None
    if attention_mask is not None:
        masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device)

    # Process each head block independently
    for head_block_idx in range(num_head_blocks):
        h_start = head_block_idx * head_block_size
        h_end = min(h_start + head_block_size, num_heads)

        # Extract head blocks
        q_g = query[:, h_start:h_end, :, :]
        k_g = key_states[:, h_start:h_end, :, :]
        v_g = value_states[:, h_start:h_end, :, :]

        attn_weights = torch.matmul(q_g, k_g.transpose(2, 3)) * scaling
        if attention_mask is not None:
            attn_weights = torch.where(attention_mask, masked_tensor, attn_weights)

        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        output_block = torch.matmul(attn_weights, v_g)

        h_output_blocks.append(output_block)
        h_attn_blocks.append(attn_weights)

    attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous()
    attn_weights = torch.cat(h_attn_blocks, dim=1)

    return attn_output, attn_weights


def blocked_q_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_q_blocks: int,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Q-blocked attention that slices the query sequence into blocks and processes each block.

    Every block attends to the full key/value tensors with a regular softmax;
    ``attention_mask`` (True = masked) is sliced on its query dim to match.
    Returns ``(attn_output, attn_weights)`` concatenated over the query blocks.
    """
    q_len = query.shape[2]
    num_q_blocks = max(1, _normalize_int(num_q_blocks))
    key_states, value_states = _get_kv_states(module, key, value)

    q_block_positions = [(i * q_len) // num_q_blocks for i in range(num_q_blocks)]
    q_output_blocks = []
    q_attn_blocks = []

    # Only needed when a mask is applied.
    masked_tensor = None
    if attention_mask is not None:
        masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device)

    for q_block_idx in range(num_q_blocks):
        q_start = q_block_positions[q_block_idx]
        if q_block_idx == num_q_blocks - 1:
            q_len_block = q_len - q_start
        else:
            q_len_block = q_block_positions[q_block_idx + 1] - q_start

        q_block = query[:, :, q_start : q_start + q_len_block, :]
        attn_mask_block = None
        if attention_mask is not None:
            attn_mask_block = attention_mask[:, :, q_start : q_start + q_len_block, :]

        attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling
        if attn_mask_block is not None:
            attn_weights = torch.where(attn_mask_block, masked_tensor, attn_weights)

        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        output_block = torch.matmul(attn_weights, value_states)

        q_output_blocks.append(output_block)
        q_attn_blocks.append(attn_weights)

    attn_output = torch.cat(q_output_blocks, dim=2).transpose(1, 2).contiguous()
    attn_weights = torch.cat(q_attn_blocks, dim=2)

    return attn_output, attn_weights


def invalid_blocking_attention_forward(*args, **kwargs):
    """Placeholder strategy: always raises ``NotImplementedError``."""
    raise NotImplementedError("Invalid blocking strategy was selected")
def _infer_head_dim(model_config: Any, num_heads: int) -> int:
    """Return the per-head dimension, from an explicit config field or ``hidden_size // num_heads``.

    Raises:
        ValueError: If the config provides neither a head dim nor a hidden size.
    """
    head_dim = get_attr_or_key(model_config, ("attention_head_dim", "head_dim", "head_dim_per_head"))
    if head_dim is not None:
        return int(head_dim)
    hidden_size = get_attr_or_key(model_config, ("hidden_size", "d_model", "model_dim", "attention_dim"))
    if hidden_size is None:
        raise ValueError("Missing head_dim or hidden_size to compute attention blocking configuration.")
    return int(hidden_size) // int(num_heads)


def _infer_data_bytes(compile_config: Dict[str, Any]) -> int:
    """Return bytes per element: explicit ``data_bytes``, else 2 with ``convert_to_fp16``, else 4."""
    explicit = compile_config.get("data_bytes")
    if explicit is not None:
        return int(explicit)
    if compile_config.get("convert_to_fp16", False):
        return 2
    return 4


def _normalize_attention_mode(raw_mode: str) -> str:
    """Reduce a mode string to ``"qkv"``, ``"kv"``, ``"q"`` or ``""`` (case-insensitive).

    Any ``"h"`` component is dropped by this reduction.
    """
    mode = raw_mode.lower()
    if "q" in mode and "kv" in mode:
        return "qkv"
    if "kv" in mode:
        return "kv"
    if "q" in mode:
        return "q"
    return ""


def _resolve_effective_blocking_mode(attention_cfg: Dict[str, Any], requested_mode: str) -> str:
    """Return the mode actually needed given the chosen block counts.

    An axis only counts as blocked when its block count exceeds 1; missing or
    None counts are treated as 1. An empty requested mode stays empty.
    """
    mode = _normalize_attention_mode(requested_mode)
    if mode == "":
        return ""
    num_q_blocks = attention_cfg.get("num_q_blocks") or 1
    num_kv_blocks = attention_cfg.get("num_kv_blocks") or 1
    if num_q_blocks > 1 and num_kv_blocks > 1:
        return "qkv"
    if num_q_blocks > 1:
        return "q"
    if num_kv_blocks > 1:
        return "kv"
    return ""


def _get_valid_num_blocks(config: Dict, requested_key: str) -> int:
    """Return ``config[requested_key]`` after checking it is present and >= 1.

    Raises:
        ValueError: If the value is missing/None or smaller than 1.
    """
    value = config.get(requested_key)
    if value is None or value < 1:
        raise ValueError(f"Invalid value {requested_key} passed in qaic_config: {value}")
    return value


def block_candidates_generator(max_length: int) -> List[int]:
    """Return candidate block counts from 1 up to ``max_length``.

    The step between candidates doubles each time a candidate is a multiple of
    four times the current step, e.g. ``[1, 2, 3, 4, 6, 8, 12, 16]`` for 16.
    """
    block_list = []
    i = 1
    step = 1
    while i <= max_length:
        block_list.append(i)
        if i % (4 * step) == 0:
            step *= 2
        i += step
    return block_list


def attention_configurator(
    bs: int,
    seq_len: int,
    ctx_len: int,
    num_heads: int,
    head_dim: int,
    num_socs: int,
    num_nsps: int,
    data_bytes: int,
    blocking_mode: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Suggest attention blocking configuration based on model and device constraints.

    Searches Q/KV block counts and keeps the pair whose estimated footprint
    (Q + KV + QK buffers) is below ``VTCM_SIZE_THRESHOLD`` with the fewest total
    blocks; ties go to the more balanced Q/KV buffer sizes.

    Returns:
        Dict with ``head_block_size``, ``num_head_blocks``, ``head_blocking_enabled``,
        ``num_q_blocks``, ``num_kv_blocks``, ``q_kv_ratio`` and ``vtcm_footprint``.
        The last four stay None when no candidate fits.
    """
    mode = (blocking_mode or "hqkv").lower()

    num_kv_blocks_list = block_candidates_generator(ctx_len) if "kv" in mode else [1]
    # NOTE(review): Q candidates are bounded by ctx_len although they split seq_len -- confirm intended.
    num_q_blocks_list = block_candidates_generator(ctx_len) if "q" in mode else [1]

    head_block_size = num_socs if "h" in mode else num_heads
    num_head_blocks = math.ceil(num_heads / head_block_size)
    num_heads_per_iter = math.ceil(head_block_size / num_socs)

    best_config = {
        "head_block_size": head_block_size,
        "num_head_blocks": num_head_blocks,
        "head_blocking_enabled": num_head_blocks > 1,
        "num_q_blocks": None,
        "num_kv_blocks": None,
        "q_kv_ratio": None,
        "vtcm_footprint": None,
    }

    def update_best_config(num_q_blocks: int, num_kv_blocks: int, q_kv_ratio: float, footprint: float) -> None:
        best_config["num_q_blocks"] = num_q_blocks
        best_config["num_kv_blocks"] = num_kv_blocks
        best_config["q_kv_ratio"] = q_kv_ratio
        best_config["vtcm_footprint"] = footprint

    for num_q_blocks in num_q_blocks_list:
        for num_kv_blocks in num_kv_blocks_list:
            q_sl_per_nsp = math.ceil(seq_len / num_nsps / num_q_blocks)
            q_size_per_nsp = num_heads_per_iter * bs * q_sl_per_nsp * head_dim * data_bytes

            kv_cl_per_nsp = math.ceil(ctx_len / num_kv_blocks)
            kv_size_per_nsp = num_heads_per_iter * bs * kv_cl_per_nsp * head_dim * data_bytes

            qk_size_per_nsp = num_heads_per_iter * bs * q_sl_per_nsp * kv_cl_per_nsp * data_bytes
            vtcm_footprint = q_size_per_nsp + kv_size_per_nsp + qk_size_per_nsp
            q_kv_ratio = max(q_size_per_nsp / kv_size_per_nsp, kv_size_per_nsp / q_size_per_nsp)
            num_total_blocks = num_q_blocks * num_kv_blocks

            if vtcm_footprint < VTCM_SIZE_THRESHOLD:
                if best_config["num_q_blocks"] is None:
                    update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint)
                elif best_config["num_q_blocks"] * best_config["num_kv_blocks"] > num_total_blocks:
                    update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint)
                elif (
                    best_config["num_q_blocks"] * best_config["num_kv_blocks"] == num_total_blocks
                    and best_config["q_kv_ratio"] >= q_kv_ratio
                ):
                    update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint)
                # Larger KV counts for this Q count only add blocks, so they can never win.
                break

    return best_config


def build_transformer_blocking_config(
    model_config: Any,
    pipeline_config: Optional[Any] = None,
    module_name: str = "transformer",
    blocking_mode: Optional[str] = None,
    ctx_len: Optional[int] = None,
    seq_len: Optional[int] = None,
    bs: Optional[int] = 1,
    compile_config: Optional[Dict[str, Any]] = None,
) -> "AttentionBlockingConfig":
    """
    Build blocking configuration based on model config + pipeline compile config.

    ``ctx_len`` and ``seq_len`` default to each other; when both are None a
    non-blocking config is returned. ``compile_config`` supplies
    ``mdp_ts_num_devices``, ``aic_num_cores`` and the data width, each with a
    default when absent. ``pipeline_config`` and ``module_name`` are not used.
    """
    if seq_len is None and ctx_len is None:
        return AttentionBlockingConfig(mode=BlockingMode.NONE)
    if ctx_len is None:
        ctx_len = seq_len
    if seq_len is None:
        # Previously int(None) failed further down when only ctx_len was given.
        seq_len = ctx_len
    # The signature default is None; treat it as "no overrides".
    compile_config = compile_config or {}

    num_heads = require_value(
        get_attr_or_key(model_config, ("num_attention_heads", "num_heads", "attention_heads", "n_heads")),
        "num attention heads",
    )
    head_dim = _infer_head_dim(model_config, int(num_heads))

    num_socs = int(compile_config.get("mdp_ts_num_devices", 1))
    num_nsps = int(compile_config.get("aic_num_cores", 1))
    data_bytes = _infer_data_bytes(compile_config)

    attention_cfg = attention_configurator(
        int(bs),
        int(seq_len),
        int(ctx_len),
        int(num_heads),
        int(head_dim),
        int(num_socs),
        int(num_nsps),
        int(data_bytes),
        blocking_mode=blocking_mode,
    )

    # NOTE(review): normalisation drops "h", so this path never selects a
    # head-blocked mode even though head_block_size is filled in -- confirm intended.
    resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv")
    effective_mode = _resolve_effective_blocking_mode(attention_cfg, resolved_mode)

    return AttentionBlockingConfig(
        mode=BlockingMode(effective_mode),
        num_kv_blocks=attention_cfg["num_kv_blocks"],
        num_q_blocks=attention_cfg["num_q_blocks"],
        head_block_size=attention_cfg["head_block_size"],
    )


def build_transformer_blocking_config_for_transform(
    model_config: Any,
    ctx_len: Optional[int] = None,
    seq_len: Optional[int] = None,
    bs: Optional[int] = 1,
    num_devices: Optional[int] = 1,
    qaic_config: Optional[dict] = None,
    **compile_options,
) -> Optional["AttentionBlockingConfig"]:
    """Build the blocking config requested through ``qaic_config``.

    Returns None unless ``qaic_config["enable_blocking"]`` is truthy. Explicit
    ``num_kv_blocks`` / ``num_q_blocks`` / ``head_block_size`` entries are used when
    the matching letters appear in ``qaic_config["blocking_mode"]`` (default
    ``"hqkv"``); if none is given the configuration is derived automatically.

    Raises:
        ValueError: For an unknown ``blocking_mode``, an explicit block value < 1,
            or a combination of explicit values with no matching ``BlockingMode``.
    """
    if not qaic_config:
        return None
    # Validated before the enable check so a bad mode string is always reported.
    blocking_mode = BlockingMode(qaic_config.get("blocking_mode", "hqkv"))
    if not qaic_config.get("enable_blocking", False):
        return None

    blocking_config = AttentionBlockingConfig()
    mode_from_config = ""
    # Order matters: each tag is prepended, giving "kv" -> "qkv" -> "hqkv".
    for key, tag in (("num_kv_blocks", "kv"), ("num_q_blocks", "q"), ("head_block_size", "h")):
        if qaic_config.get(key, False) and tag in blocking_mode:
            mode_from_config = tag + mode_from_config
            setattr(blocking_config, key, _get_valid_num_blocks(qaic_config, key))

    # qaic_config did not provide any blocking details: derive them.
    if mode_from_config == "":
        return build_transformer_blocking_config(
            model_config,
            blocking_mode=blocking_mode,
            ctx_len=ctx_len,
            seq_len=seq_len,
            bs=bs,
            compile_config={"mdp_ts_num_devices": num_devices, **compile_options},
        )

    try:
        blocking_config.mode = BlockingMode(mode_from_config)
    except ValueError as err:
        supported = ", ".join(repr(m.value) for m in BlockingMode if m.value)
        raise ValueError(
            f"Unsupported blocking combination {mode_from_config!r} derived from qaic_config; "
            f"supported modes are: {supported}"
        ) from err

    return blocking_config
DeepseekV3MLP, -# DeepseekV3MoE, -# rotate_half, -# repeat_kv, -# DeepseekV3Attention, -# DeepseekV3DecoderLayer, -# DeepseekV3Model, -# DeepseekV3ForCausalLM, -# DeepseekV3PreTrainedModel, -# logger, -# ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -234,11 +219,11 @@ def __qeff_init__( self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) - out = torch.matmul(per_head_q_up[0,:,:], per_head_k_up[0,:,:]) - for i in range(1, self.num_heads): - x = torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:]) - out = torch.cat((out,x), 0) - fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank) + fusedqk_list = [] + for i in range(self.num_heads): + fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) + fusedqk = torch.cat(fusedqk_list, dim=0) + fusedqk = fusedqk.reshape(self.num_heads, -1, self.kv_lora_rank) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) @@ -298,42 +283,51 @@ def fused_forward( if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - x = [] - for k in range(n_head_ckv): - k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim]) - k_nope = k_nope.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) + breakpoint() + blocking_config = getattr(self, "attn_blocking_config", None) + num_kv_blocks = blocking_config.num_kv_blocks #1 + seq_len = compressed_kv.shape[-2] + block_size = -(-seq_len // num_kv_blocks) #32 - for i in range(k*p, (k+1)*p): + attn_output_list = [] + for k in range(n_head_ckv): + attn_weights_list = [] + for j in range(num_kv_blocks): if enable_absorption: if absorb_online: - if i==0: + if j==0: print("online absorption") - out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:]) - out = 
out.reshape(1, -1, self.kv_lora_rank) + out = torch.matmul(self.per_head_q_up[k*p:(k+1)*p,:,:], self.per_head_k_up[k*p:(k+1)*p,:,:]) out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) else: - if i==0: + if j==0: print("using fused qk") - out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:]) + out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[k*p:(k+1)*p,:,:]) - out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1) - kva_kpe = torch.cat((kva[:,k,:,:],k_pe[:,k,:,:]), -1).unsqueeze(1) + out3 = torch.cat((out2, q_pe[:,k*p:(k+1)*p,:,:]), -1) + kva_kpe = torch.cat((kva[:,k,j*block_size:(j+1)*block_size,:],k_pe[:,k,j*block_size:(j+1)*block_size,:]), -1).unsqueeze(1) attn_weights = torch.matmul(out3, kva_kpe.transpose(2,3)) * self.softmax_scale else: - if i==0: + if j==0: print("no absorption") - query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1).unsqueeze(1) - key_states = torch.cat((k_nope[:,i%p,:,:], k_pe[:,k,:,:]), -1).unsqueeze(1) + k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim]) + k_nope = k_nope.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) + key_states = torch.cat((k_nope[:,:,j*block_size:(j+1)*block_size,:], k_pe[:,k,j*block_size:(j+1)*block_size,:].unsqueeze(1).repeat(1,p,1,1)), -1) + query_states = torch.cat((q_nope[:,k*p:(k+1)*p,:,:], q_pe[:,k*p:(k+1)*p,:,:]), -1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + attn_weights_list.append(attn_weights) + + attn_weights = torch.cat(attn_weights_list, dim=-1) + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - attn_weights 
= F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states[:,i,:,:]) - x.append(attn_output) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, value_states[:,k*p:(k+1)*p,:,:]) + attn_output_list.append(attn_output) - attn_output = torch.cat(x, dim=1) + attn_output = torch.cat(attn_output_list, dim=1) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ba0d6f4697..8bb0851e2d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -46,7 +46,6 @@ _configure_proxy_for_model, ) from QEfficient.transformers.models.pytorch_transforms import ( - BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, @@ -1370,6 +1369,33 @@ def export( ) return self.onnx_path + def transform( + self, + ctx_len: Optional[int] = None, + seq_len: Optional[int] = None, + bs: Optional[int] = 1, + num_devices: int = 1, + qaic_config: Optional[dict] = None, + **compiler_options, + ): + self.vision_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + + self.lang_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + def compile( self, img_size: Optional[int] = None, @@ -1988,9 +2014,6 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None - if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: - 
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - @classmethod def from_pretrained( cls, @@ -2795,8 +2818,14 @@ def __init__( self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config - self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) - self.is_tlm = transformed + self.model.pretrained_path = kwargs.pop("pretrained_model_name_or_path", None) + # self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) + # self.is_tlm = transformed + self.is_tlm = ( + (qaic_config is not None) + and (qaic_config.get("speculative_model_type") is not None) + and (model.__class__ in SpDTransform._module_mapping) + ) self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.ccl_enabled = False @@ -2814,14 +2843,9 @@ def __init__( # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the # previous transform function. - self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) + # self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) # TODO : Update in qaic_config isn't updated in the hash due to SpDTransforms. Need to move # SpDTransforms to PytorchTransforms. 
- if self.is_tlm: - self.model.qaic_config["return_pdfs"] = True - - if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: - BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -3005,6 +3029,17 @@ def export( """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + # increase seq_len if using a larger number of blocks + if self.hash_params.get("blocking_kwargs", None): + max_blocks = -1 + for num_blocks in self.hash_params.get("blocking_kwargs").values(): + max_blocks = max(max_blocks, num_blocks) + block_size = -(-seq_len // max_blocks) + while seq_len < max_blocks or (seq_len % max_blocks > block_size): + seq_len = seq_len * 2 + block_size = -(-seq_len // max_blocks) + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS # kv_cache_shape = get_padding_shape_from_config( # self.model.config, fbs if self.continuous_batching else bs, seq_len diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 6246268779..39ab5010b4 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1107,21 +1107,50 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu return model, transformed -class BlockedKVAttentionTransform: - _module_mapping = { - QEffLlamaAttention, - QEffQwen2_5_VLAttention, - } +def get_decoder_layer_classes_for_export(model: nn.Module) -> set: + """ + Dynamically determine which DecoderLayer classes should be exported as functions + based on the model's architecture using the existing KVCacheTransform mapping. 
+ """ + # Define patterns that identify decoder layer classes + DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] + + # Get all QEff classes that are decoder layers from the existing mapping + decoder_layer_classes = set() + + for original_class, qeff_class in KVCacheTransform._module_mapping.items(): + # Check if the QEff class name contains decoder layer patterns + qeff_class_name = qeff_class.__name__ + if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): + decoder_layer_classes.add(qeff_class) + + # Filter to only include classes that are actually used in the current model + model_decoder_classes = set() + for module in model.modules(): + if module.__class__ in decoder_layer_classes: + model_decoder_classes.add(module.__class__) + + return model_decoder_classes + + +class BlockingAttentionTransform: + _skip_classes = {} @classmethod - def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, attn_blocking_config) -> Tuple[nn.Module, bool]: transformed = False + supported_attention_classes = { + qeff_class + for qeff_class in KVCacheTransform._module_mapping.values() + if qeff_class.__name__.endswith("Attention") + } for module in model.modules(): - if type(module) in cls._module_mapping: - repl_module = type(module) - module.__class__ = repl_module - module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module) - transformed = True # Set to True if at least one transformation occurs - elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping: - warnings.warn(f"KV blocking is not yet supported for {type(module)}.") + if type(module) in cls._skip_classes: + warnings.warn(f"Blocking is not yet supported for {type(module)}.") + continue + if type(module) in supported_attention_classes or model.config.model_type == "kimi_k2": + module.attn_blocking_config = attn_blocking_config + transformed = True + elif 
module.__class__.__name__.endswith("Attention") and type(module) not in supported_attention_classes: + warnings.warn(f"Blocking is not yet supported for {type(module)}.") return model, transformed diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index 3d6583f857..d25198f9d6 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -18,6 +18,7 @@ dump_qconfig, generate_mdp_partition_config, get_num_layers_from_config, + get_attr_or_key, get_num_layers_vlm, get_onnx_dir_name, get_padding_shape_from_config, @@ -34,6 +35,7 @@ onnx_exists, padding_check_and_fix, qpc_exists, + require_value, ) from QEfficient.utils.hash_utils import ( # noqa: F401 create_export_hash, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 26bae7a34b..e1b88dfab2 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -837,3 +837,20 @@ def custom_format_warning(msg, category, *args, **kwargs): YELLOW = "\033[93m" RESET = "\033[0m" return f"{YELLOW}[Warning]: {msg}{RESET}\n" + + +def get_attr_or_key(obj: Any, names: Tuple[str, ...], default: Any = None) -> Any: + if obj is None: + return default + for name in names: + if isinstance(obj, dict) and name in obj: + return obj[name] + if hasattr(obj, name): + return getattr(obj, name) + return default + + +def require_value(value: Any, label: str) -> Any: + if value is None: + raise ValueError(f"Missing required {label} to compute blocking configuration.") + return value diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index b3782605e1..cc0b87b604 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -33,6 +33,9 @@ ), } +# Blocking defaults +VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 + # Compiler defaults DEFAULT_AIC_NUM_CORES = 16 DEFAULT_AIC_MXPF6_MATMUL = False @@ -210,7 +213,6 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts 
for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 - NUM_KV_BLOCKS = 8 MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS SAMPLER_OPS = { "repetition_penalties", diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index da3231190e..4501b1a932 100644 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -131,6 +131,7 @@ def _generate_export_hash(qeff_model, args, kwargs, func): output_names=all_args.get("output_names"), dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), + blocking_kwargs=all_args.get("blocking_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), ) diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 10e6686d0c..4cb137895e 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -7,6 +7,7 @@ import hashlib import json +from dataclasses import asdict, is_dataclass from typing import Dict from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN @@ -16,6 +17,9 @@ def json_serializable(obj): if isinstance(obj, set): # Convert set to a sorted list of strings for consistent hashing return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]) + if is_dataclass(obj): + # Convert dataclass to dict for serialization + return asdict(obj) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -58,6 +62,10 @@ def create_export_hash(**kwargs): export_params["dynamic_axes"] = kwargs.get("dynamic_axes") export_hash_params["export_params"] = export_params + blocking_kwargs = export_hash_params.pop("blocking_kwargs", None) + if blocking_kwargs: + export_hash_params.update(blocking_kwargs) + export_kwargs = kwargs.get("export_kwargs") if export_kwargs: export_hash_params.update(export_kwargs) diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 3f6841eefa..b50dc4ed2f 100644 
--- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -39,11 +39,10 @@ inputs.pop("token_type_ids", None) inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -cache_len = 128 -pad_shape_k = (1, 64, cache_len, 192) -pad_shape_v = (1, 64, cache_len, 128) -pad_shape_ckv = (1, num_kv_heads_repeat, cache_len, 512) -pad_shape_k_pe = (1, num_kv_heads_repeat, cache_len, 64) +pad_shape_k = (1, 64, CTX_LEN, 192) +pad_shape_v = (1, 64, CTX_LEN, 128) +pad_shape_ckv = (1, num_kv_heads_repeat, CTX_LEN, 512) +pad_shape_k_pe = (1, num_kv_heads_repeat, CTX_LEN, 64) past_key_values = [] compressed_kvs = [] @@ -91,6 +90,11 @@ print("Completion:", repr(predicted_string)) + + + +qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} + prefill_seq_len = 1 ctx_len = 1024 @@ -104,6 +108,7 @@ num_devices=num_kv_heads_repeat, num_cores=16, #prefill_only=True, + qaic_config=qaic_config, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 2e44321b0f16cc54e0828da94ea53e5f124f84ff Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:25:00 +0000 Subject: [PATCH 13/51] Update modeling_deepseek_qeff.py Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek_qeff.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index deff7ae7cd..97cb207218 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -283,11 +283,13 @@ def fused_forward( if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - breakpoint() blocking_config = getattr(self, "attn_blocking_config", None) - num_kv_blocks = blocking_config.num_kv_blocks #1 + num_kv_blocks = 1 + if blocking_config is not None: + num_kv_blocks = 
blocking_config.num_kv_blocks + print("num_kv_blocks : ", num_kv_blocks) seq_len = compressed_kv.shape[-2] - block_size = -(-seq_len // num_kv_blocks) #32 + block_size = -(-seq_len // num_kv_blocks) attn_output_list = [] for k in range(n_head_ckv): From 8c99ee0a1c1144d6d5246afbde85a8d4fc8d7d00 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:27:01 +0000 Subject: [PATCH 14/51] update kv heads in example_inputs during export and kv_blocking Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 4 +++- .../deepseek_v3/modeling_deepseek_qeff.py | 22 ++++++++++--------- .../transformers/models/modeling_auto.py | 5 +++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a2eeef3b0a..d524f65f33 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -289,7 +289,6 @@ def _export( else: input_names.append(param) - # import ipdb; ipdb.set_trace() try: torch.onnx.export( self.model, @@ -348,6 +347,7 @@ def get_onnx_path( retain_full_kv: Optional[bool] = False, enable_mla: Optional[bool] = False, mla_absorption_config: Optional[bool] = False, + mdp_ts_num_devices: Optional[int] = 1, ): kwargs = { "offload_pt_weights": offload_pt_weights, @@ -355,6 +355,7 @@ def get_onnx_path( "retain_full_kv": retain_full_kv, "enable_mla": enable_mla, "mla_absorption_config": mla_absorption_config, + "mdp_ts_num_devices": mdp_ts_num_devices, } if prefill_only: @@ -491,6 +492,7 @@ def _compile( retain_full_kv, enable_mla, mla_absorption_config, + mdp_ts_num_devices, ) ) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 97cb207218..ab227c8405 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ 
b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -270,6 +270,14 @@ def fused_forward( n_head_ckv = compressed_kv.shape[1] p = self.num_heads//n_head_ckv + blocking_config = getattr(self, "attn_blocking_config", None) + num_kv_blocks = 1 + if blocking_config is not None: + num_kv_blocks = blocking_config.num_kv_blocks + print("num_kv_blocks : ", num_kv_blocks) + ctx_len = compressed_kv.shape[-2] + block_size = -(-ctx_len // num_kv_blocks) + value_out = [] for i in range(n_head_ckv): value_states_ph = torch.matmul(kva[:,i,:,:], self.v_up[:, :, i*p*self.v_head_dim: (i+1)*p*self.v_head_dim]) @@ -283,18 +291,12 @@ def fused_forward( if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - blocking_config = getattr(self, "attn_blocking_config", None) - num_kv_blocks = 1 - if blocking_config is not None: - num_kv_blocks = blocking_config.num_kv_blocks - print("num_kv_blocks : ", num_kv_blocks) - seq_len = compressed_kv.shape[-2] - block_size = -(-seq_len // num_kv_blocks) - attn_output_list = [] for k in range(n_head_ckv): attn_weights_list = [] for j in range(num_kv_blocks): + kv_start_index = j*block_size + kv_end_index = min(ctx_len, (j+1)*block_size) if enable_absorption: if absorb_online: if j==0: @@ -307,14 +309,14 @@ def fused_forward( out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[k*p:(k+1)*p,:,:]) out3 = torch.cat((out2, q_pe[:,k*p:(k+1)*p,:,:]), -1) - kva_kpe = torch.cat((kva[:,k,j*block_size:(j+1)*block_size,:],k_pe[:,k,j*block_size:(j+1)*block_size,:]), -1).unsqueeze(1) + kva_kpe = torch.cat((kva[:,k,kv_start_index:kv_end_index,:],k_pe[:,k,kv_start_index:kv_end_index,:]), -1).unsqueeze(1) attn_weights = torch.matmul(out3, kva_kpe.transpose(2,3)) * self.softmax_scale else: if j==0: print("no absorption") k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim]) k_nope = k_nope.view(bsz, -1, p, 
self.qk_nope_head_dim).transpose(1, 2) - key_states = torch.cat((k_nope[:,:,j*block_size:(j+1)*block_size,:], k_pe[:,k,j*block_size:(j+1)*block_size,:].unsqueeze(1).repeat(1,p,1,1)), -1) + key_states = torch.cat((k_nope[:,:,kv_start_index:kv_end_index,:], k_pe[:,k,kv_start_index:kv_end_index,:].unsqueeze(1).repeat(1,p,1,1)), -1) query_states = torch.cat((q_nope[:,k*p:(k+1)*p,:,:], q_pe[:,k*p:(k+1)*p,:,:]), -1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8bb0851e2d..939a883a42 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3186,13 +3186,14 @@ def export( qaic_config=self.model.qaic_config, ) if enable_mla: + mdp_ts_num_devices = kwargs.get("mdp_ts_num_devices", 1) example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - ckv = torch.zeros((bs, 4, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros((bs, 4, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + ckv = torch.zeros((bs, mdp_ts_num_devices, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) + k_pe = torch.zeros((bs, mdp_ts_num_devices, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) example_inputs["compressed_kvs"][i].append(ckv) example_inputs["compressed_kvs"][i].append(k_pe) dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} From 2859c196774213e2ba0b5481a3ad232fb9cb5cd2 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 1 Apr 2026 16:49:23 +0000 Subject: [PATCH 15/51] Add example script Signed-off-by: Mamta Singh --- 
examples/export_kimik2.py | 41 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py index fb8a724696..b9c461281d 100644 --- a/examples/export_kimik2.py +++ b/examples/export_kimik2.py @@ -1,40 +1,39 @@ import torch - -torch.set_printoptions( - precision=4, - edgeitems=2, - threshold=50, # max number of elements printed - linewidth=120, -) from transformers import AutoModelForCausalLM, AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM +#parameters to be configured +TS=4 +num_hidden_layers=2 +enable_mla=True +mla_absorption_config={"enable": True, "online": False} +prefill_seq_len = 1 +ctx_len = 2048 +qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} + + model = AutoModelForCausalLM.from_pretrained( - "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", + "moonshotai/Kimi-K2-Thinking", torch_dtype=torch.float32, - num_hidden_layers=2, + num_hidden_layers=num_hidden_layers, trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -qeff_model = QEFFAutoModelForCausalLM(model) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat = TS) -onnx_path = qeff_model.export( - prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": False} -) -print(onnx_path) qpc_path = qeff_model.compile( - prefill_seq_len=1, - ctx_len=128, - enable_mla=True, - mla_absorption_config={"enable": True, "online": False}, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + enable_mla=enable_mla, + mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, - num_devices=4, + num_devices=TS, num_cores=16, + #prefill_only=True, + qaic_config=qaic_config, ) -print(qpc_path) -prompts = "Once upon a time," qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 
13da8ef5d9d94d374fdd31e3528138992851e78e Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 7 Apr 2026 11:44:07 +0000 Subject: [PATCH 16/51] update fused forward for loops Signed-off-by: Mamta Singh --- .../deepseek_v3/modeling_deepseek_qeff.py | 159 +++++++++++------- 1 file changed, 97 insertions(+), 62 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index ab227c8405..a0fe1592e0 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -219,12 +219,14 @@ def __qeff_init__( self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) - fusedqk_list = [] + '''fusedqk_list = [] for i in range(self.num_heads): - fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) + fusedqk_list.append(torch.matmul(per_head_q_up[i, :, :], per_head_k_up[i, :, :])) fusedqk = torch.cat(fusedqk_list, dim=0) fusedqk = fusedqk.reshape(self.num_heads, -1, self.kv_lora_rank) + ''' + fusedqk = torch.bmm(per_head_q_up, per_head_k_up) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) def fused_forward( @@ -245,7 +247,7 @@ def fused_forward( bsz, q_len, _ = hidden_states.size() compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) @@ -268,76 +270,109 @@ def fused_forward( enable_absorption = False n_head_ckv = compressed_kv.shape[1] - p = self.num_heads//n_head_ckv - - blocking_config = 
getattr(self, "attn_blocking_config", None) - num_kv_blocks = 1 - if blocking_config is not None: - num_kv_blocks = blocking_config.num_kv_blocks - print("num_kv_blocks : ", num_kv_blocks) - ctx_len = compressed_kv.shape[-2] - block_size = -(-ctx_len // num_kv_blocks) - - value_out = [] - for i in range(n_head_ckv): - value_states_ph = torch.matmul(kva[:,i,:,:], self.v_up[:, :, i*p*self.v_head_dim: (i+1)*p*self.v_head_dim]) - value_states_ph = value_states_ph.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) - value_out.append(value_states_ph) - value_states = torch.cat(value_out, dim=1) - - cos, sin = self.rotary_emb(value_states_ph, seq_len=32 * 1024) + p = self.num_heads // n_head_ckv + + ############################################################################ + + + ''' + kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) + + value_states=torch.matmul(kva_expanded, v_up_per_head) + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - attn_output_list = [] - for k in range(n_head_ckv): - attn_weights_list = [] - for j in range(num_kv_blocks): - kv_start_index = j*block_size - kv_end_index = min(ctx_len, (j+1)*block_size) - if enable_absorption: - if absorb_online: - if j==0: - print("online absorption") - out = torch.matmul(self.per_head_q_up[k*p:(k+1)*p,:,:], self.per_head_k_up[k*p:(k+1)*p,:,:]) - out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out) - else: - if j==0: - print("using fused qk") - out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[k*p:(k+1)*p,:,:]) - - out3 = torch.cat((out2, q_pe[:,k*p:(k+1)*p,:,:]), -1) - kva_kpe = 
torch.cat((kva[:,k,kv_start_index:kv_end_index,:],k_pe[:,k,kv_start_index:kv_end_index,:]), -1).unsqueeze(1) - attn_weights = torch.matmul(out3, kva_kpe.transpose(2,3)) * self.softmax_scale - else: - if j==0: - print("no absorption") - k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim]) - k_nope = k_nope.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2) - key_states = torch.cat((k_nope[:,:,kv_start_index:kv_end_index,:], k_pe[:,k,kv_start_index:kv_end_index,:].unsqueeze(1).repeat(1,p,1,1)), -1) - query_states = torch.cat((q_nope[:,k*p:(k+1)*p,:,:], q_pe[:,k*p:(k+1)*p,:,:]), -1) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - attn_weights_list.append(attn_weights) - - attn_weights = torch.cat(attn_weights_list, dim=-1) - - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states[:,k*p:(k+1)*p,:,:]) - attn_output_list.append(attn_output) - - attn_output = torch.cat(attn_output_list, dim=1) + k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) + + if enable_absorption: + if absorb_online: + print("online absorption") + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + print("using fused qk") + #breakpoint() + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) + + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + else: + print("no absorption") + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, 
self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) + + k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ''' + + + + + kva_expanded = kva.transpose(1,0) #1,4,128,512 -> (4,1,128,512) + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) #64,512,128 + + v_up_per_head = v_up_per_head.reshape(-1,p, self.kv_lora_rank, self.v_head_dim) #4,16,512,128 + + value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, self.v_head_dim, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) #1,4,128,64 + + k_pe_expanded = k_pe.transpose(1,0) #reshape(-1, bsz, self.qk_nope_head_dim, self.qk_rope_head_dim) #4,1,128,64 + + if enable_absorption: + if absorb_online: + print("online absorption") + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + print("using fused qk") + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) #1,1,32,1536 , 64,1536,512 -> 1,64,32,512 + + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1).reshape(4,16,32,576) #1,64,32,512, 1,64,32,64 -> 1,64,32,576 -> 4,16,32,576 + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 + else: + print("no absorption") + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, 
self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) + + k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + attn_weights = attn_weights.reshape(1, self.num_heads, 32, 128) + + + + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, compressed_kvs, value_states + def forward( self, hidden_states: torch.Tensor, From 9e432dcfe41f108ff126ed6db7f53209d43ea7c4 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 7 Apr 2026 19:02:53 +0000 Subject: [PATCH 17/51] fix reshapes in fused_forward Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek_qeff.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index a0fe1592e0..d058820b5a 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -322,7 +322,7 @@ def fused_forward( v_up_per_head = v_up_per_head.reshape(-1,p, self.kv_lora_rank, self.v_head_dim) #4,16,512,128 - value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, self.v_head_dim, 
self.v_head_dim) #4,16,128,128 -> 1,64,128,128 + value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, -1, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -341,8 +341,9 @@ def fused_forward( print("using fused qk") q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) #1,1,32,1536 , 64,1536,512 -> 1,64,32,512 - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1).reshape(4,16,32,576) #1,64,32,512, 1,64,32,64 -> 1,64,32,576 -> 4,16,32,576 - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) #1,64,32,512, 1,64,32,64 -> 1,64,32,576 -> + query_states = query_states.reshape(-1,p,q_len,self.kv_lora_rank + self.qk_rope_head_dim) #1,64,32,576 -> 4,16,32,576 + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 else: print("no absorption") q_nope = torch.bmm(q_a_proj_out, self.q_up) @@ -354,7 +355,7 @@ def fused_forward( key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - attn_weights = attn_weights.reshape(1, self.num_heads, 32, 128) + attn_weights = attn_weights.reshape(bsz, self.num_heads, q_len, -1) From b62511389f02e949221fb8cddb0ee9e23f2ace3d Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 7 Apr 2026 21:48:17 +0000 Subject: [PATCH 18/51] write_only and read_onnly for blocked_kv Signed-off-by: Mamta Singh --- QEfficient/customop/ctx_scatter_gather.py | 1 + QEfficient/transformers/cache_utils.py | 74 ++++ .../deepseek_v3/modeling_deepseek_qeff.py | 360 ++++++++++++++---- 3 files changed, 368 insertions(+), 67 deletions(-) diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py 
index 7b15effe76..59bfe6af03 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -163,6 +163,7 @@ class CtxGatherFuncBlockedKV(torch.autograd.Function): def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 97e4c56c02..8e478528f3 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -390,6 +390,66 @@ def update_k_pe(self, k_pe_cache, cache_kwargs): k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out) return k_pe_out + def read_only_blocked_ckv(self, start_index, end_index, cache_kwargs): + # Gather + ckv_out = self.ckv + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = ckv_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] 
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + ckv_out = CtxGatherFuncBlockedKV.apply(ckv_out, ctx_indices) + + ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) + return ckv_out + + def read_only_blocked_k_pe(self, start_index, end_index, cache_kwargs): + # Gather + k_pe_out = self.k_pe + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = k_pe_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + k_pe_out = CtxGatherFuncBlockedKV.apply(k_pe_out, ctx_indices) + + k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out) + return k_pe_out + + def write_only_k_pe(self, k_pe_cache, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later + + self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) + return self.k_pe + + def write_only_ckv(self, compressed_kv, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later + + self.ckv = 
CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv) + return self.ckv + class QEffDynamicCompressedKVRopeCache: def __init__( @@ -406,6 +466,20 @@ def update_ckv(self, ckv, layer_idx, cache_kwargs): def update_k_pe(self, k_pe, layer_idx, cache_kwargs): return self.layers[layer_idx].update_k_pe(k_pe, cache_kwargs) + def read_only_blocked_ckv(self, start_index, end_index, layer_idx, cache_kwargs): + return self.layers[layer_idx].read_only_blocked_ckv(start_index, end_index, cache_kwargs) + + def read_only_blocked_k_pe(self, start_index, end_index, layer_idx, cache_kwargs): + return self.layers[layer_idx].read_only_blocked_k_pe(start_index, end_index, cache_kwargs) + + def write_only_ckv(self, ckv, layer_idx, cache_kwargs): + #self.append_new_layers(layer_idx) + return self.layers[layer_idx].write_only_ckv(ckv, cache_kwargs) + + def write_only_k_pe(self, k_pe, layer_idx, cache_kwargs): + #self.append_new_layers(layer_idx) + return self.layers[layer_idx].write_only_k_pe(k_pe, cache_kwargs) + @classmethod def from_legacy_cache(cls, past_key_values): cache = cls() diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index d058820b5a..e0e2fc01f4 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -189,6 +189,7 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed +''' class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" @@ -198,35 +199,37 @@ def __qeff_init__( q_up, q_rope = self.q_b_proj.weight.T.view( -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - 
self.q_up = torch.nn.Parameter(q_up.detach().clone()) + q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) - self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) - self.k_up = torch.nn.Parameter(k_up.detach().clone()) - self.v_up = torch.nn.Parameter(v_up.detach().clone()) + self.k_up = torch.nn.Parameter(k_up.detach()) + self.v_up = torch.nn.Parameter(v_up.detach()) per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) per_head_k_up = ( self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) ) - self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) - self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) + per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) + self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) + self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) + self.per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) - '''fusedqk_list = [] + fusedqk_list = [] for i in range(self.num_heads): - fusedqk_list.append(torch.matmul(per_head_q_up[i, :, :], per_head_k_up[i, :, :])) + fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) fusedqk = torch.cat(fusedqk_list, dim=0) - fusedqk = fusedqk.reshape(self.num_heads, -1, self.kv_lora_rank) - ''' + fusedqk = fusedqk.reshape(1, self.num_heads, -1, self.kv_lora_rank) - fusedqk = torch.bmm(per_head_q_up, per_head_k_up) 
self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) def fused_forward( @@ -247,15 +250,14 @@ def fused_forward( bsz, q_len, _ = hidden_states.size() compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: @@ -270,50 +272,247 @@ def fused_forward( enable_absorption = False n_head_ckv = compressed_kv.shape[1] - p = self.num_heads // n_head_ckv + p = self.num_heads//n_head_ckv - ############################################################################ - - - ''' - kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) - v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) - value_states=torch.matmul(kva_expanded, v_up_per_head) - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, 
self.qk_rope_head_dim) - if enable_absorption: - if absorb_online: - print("online absorption") - out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) - else: - print("using fused qk") - #breakpoint() - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) + attn_output_list = [] + attn_weights_list = [] + for head_block_idx in range(self.num_heads//n_head_ckv): + h_start = head_block_idx * n_head_ckv + h_end = min(h_start+n_head_ckv, self.num_heads) + + if enable_absorption: + if absorb_online: + qup_kupT = torch.matmul(self.per_head_q_up[:, h_start:h_end,:,:], self.per_head_k_up[:, h_start:h_end,:,:]) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end,:,:]) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) + krope_nope = torch.cat((kva, k_pe), dim=-1) + attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + else: + knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) + krope_nope = torch.cat((knope, k_pe), dim=-1) + qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) + attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, 
dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + attn_output = torch.cat(attn_output_list, dim=1) + attn_weights = torch.cat(attn_weights_list, dim=1) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, compressed_kvs, None - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) else: - print("no absorption") - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - query_states = torch.cat((q_nope, q_pe), dim=-1) + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(bsz, q_len, -1, 576)[:, :, 0, :].reshape(bsz, q_len, 576) 
+ + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) - k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) - k_nope = torch.matmul(kva_expanded, k_up_per_head) - key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - ''' + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value, value_states + +''' +class QEffDeepseekV3Attention(nn.Module): 
+ """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" + + def __qeff_init__( + self, + ): + q_up, q_rope = self.q_b_proj.weight.T.view( + -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim + ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + + self.q_up = torch.nn.Parameter(q_up.detach().clone()) + q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) + + self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) + + self.k_up = torch.nn.Parameter(k_up.detach().clone()) + self.v_up = torch.nn.Parameter(v_up.detach().clone()) + per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + per_head_k_up = ( + self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) + ) + self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) + self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) + + fusedqk = torch.bmm(per_head_q_up, per_head_k_up) + self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) + + def fused_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: 
Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + compressed_kv = self.kv_a_layernorm(compressed_kv) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + #compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kvs.write_only_ckv(compressed_kv, self.layer_idx, cache_kwargs) + + kva = compressed_kv + + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) + else: + enable_absorption = False + + n_head_ckv = compressed_kv.shape[1] + p = self.num_heads // n_head_ckv + + ############################################################################ + + +# kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) +# v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) +# +# value_states=torch.matmul(kva_expanded, v_up_per_head) +# +# cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) +# q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) +# +# if compressed_kvs is not 
None: +# k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) +# +# k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) +# +# if enable_absorption: +# if absorb_online: +# print("online absorption") +# out = torch.matmul(self.per_head_q_up, self.per_head_k_up) +# q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) +# else: +# print("using fused qk") +# #breakpoint() +# q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) +# +# query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) +# key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) +# else: +# print("no absorption") +# q_nope = torch.bmm(q_a_proj_out, self.q_up) +# q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) +# query_states = torch.cat((q_nope, q_pe), dim=-1) +# +# k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) +# k_nope = torch.matmul(kva_expanded, k_up_per_head) +# key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) +# +# attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + + + num_kv_blocks = 8 + print("num_kv_blocks : ", num_kv_blocks) + ctx_len = compressed_kv.shape[-2] + block_size = -(-ctx_len // num_kv_blocks) + + @@ -324,40 +523,66 @@ def fused_forward( value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, -1, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) #1,4,128,64 + #k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) #1,4,128,64 + k_pe = compressed_kvs.write_only_k_pe(k_pe, 
self.layer_idx, cache_kwargs) - k_pe_expanded = k_pe.transpose(1,0) #reshape(-1, bsz, self.qk_nope_head_dim, self.qk_rope_head_dim) #4,1,128,64 + #k_pe_expanded = k_pe.transpose(1,0) #reshape(-1, bsz, self.qk_nope_head_dim, self.qk_rope_head_dim) #4,1,128,64 - if enable_absorption: - if absorb_online: - print("online absorption") - out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) - else: - print("using fused qk") - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) #1,1,32,1536 , 64,1536,512 -> 1,64,32,512 - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) #1,64,32,512, 1,64,32,64 -> 1,64,32,576 -> - query_states = query_states.reshape(-1,p,q_len,self.kv_lora_rank + self.qk_rope_head_dim) #1,64,32,576 -> 4,16,32,576 - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 - else: - print("no absorption") - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - query_states = torch.cat((q_nope, q_pe), dim=-1) - k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) - k_nope = torch.matmul(kva_expanded, k_up_per_head) - key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - attn_weights = attn_weights.reshape(bsz, self.num_heads, q_len, -1) + attn_weights_list=[] + for j in range(num_kv_blocks): + kv_start_index = j * block_size + kv_end_index = min(ctx_len, (j + 1) * block_size) + #kva_expanded = kva.transpose(1,0) #1,4,128,512 -> (4,1,128,512) + #v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) #64,512,128 + #v_up_per_head = v_up_per_head.reshape(-1,p, self.kv_lora_rank, self.v_head_dim) #4,16,512,128 + 
#value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, -1, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 + + + kva = compressed_kvs.read_only_blocked_ckv(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) + kva_expanded = kva.transpose(1,0) + k_pe = compressed_kvs.read_only_blocked_k_pe(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) + k_pe_expanded = k_pe.transpose(1,0) + + + + + if enable_absorption: + if absorb_online: + print("online absorption") + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + print("using fused qk") + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) #1,1,32,1536 , 64,1536,512 -> 1,64,32,512 + + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) #1,64,32,512, 1,64,32,64 -> 1,64,32,576 -> + query_states = query_states.reshape(-1,p,q_len,self.kv_lora_rank + self.qk_rope_head_dim) #1,64,32,576 -> 4,16,32,576 + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 + else: + print("no absorption") + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) + + k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + attn_weights = attn_weights.reshape(bsz, self.num_heads, q_len, -1) + + attn_weights_list.append(attn_weights) + + attn_weights = torch.cat(attn_weights_list, dim=-1) if attention_mask is not None: # no matter the length, we just slice it @@ -432,6 +657,7 @@ def forward( return attn_output, attn_weights, past_key_value, value_states 
+ class QEffDeepseekV3MoE(nn.Module): def __qeff_init__( self, From 768ce8544b7adc89de4ba76138831f6bcdfe5962 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 8 Apr 2026 08:13:53 +0000 Subject: [PATCH 19/51] removed commented code Signed-off-by: Onkar Chougule --- .../deepseek_v3/modeling_deepseek_qeff.py | 204 ++++++++++++------ 1 file changed, 140 insertions(+), 64 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index e0e2fc01f4..ed89f0fd2c 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -391,7 +391,144 @@ def forward( ''' class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" + def __qeff_init___for_h_blocking( + self, + ): + q_up, q_rope = self.q_b_proj.weight.T.view( + -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim + ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + self.q_up = torch.nn.Parameter(q_up.detach().clone()) + + q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) + self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) + + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) + + self.k_up = torch.nn.Parameter(k_up.detach()) + self.v_up = torch.nn.Parameter(v_up.detach()) + per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + per_head_k_up = ( + self.k_up.squeeze(0).view(-1, 
self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) + ) + per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) + self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) + self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) + self.per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) + + fusedqk_list = [] + for i in range(self.num_heads): + fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) + fusedqk = torch.cat(fusedqk_list, dim=0) + fusedqk = fusedqk.reshape(1, self.num_heads, -1, self.kv_lora_rank) + + self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) + + def fused_forward_h_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) + compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + q_nope = 
torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + compressed_kv = self.kv_a_layernorm(compressed_kv) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) + + kva = compressed_kv + + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) + else: + enable_absorption = False + + n_head_ckv = compressed_kv.shape[1] + p = self.num_heads//n_head_ckv + + + + cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + + attn_output_list = [] + attn_weights_list = [] + for head_block_idx in range(self.num_heads//n_head_ckv): + h_start = head_block_idx * n_head_ckv + h_end = min(h_start+n_head_ckv, self.num_heads) + + if enable_absorption: + if absorb_online: + qup_kupT = torch.matmul(self.per_head_q_up[:, h_start:h_end,:,:], self.per_head_k_up[:, h_start:h_end,:,:]) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end,:,:]) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) + krope_nope = torch.cat((kva, k_pe), dim=-1) + attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, 
h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + else: + knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) + krope_nope = torch.cat((knope, k_pe), dim=-1) + qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) + attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + attn_output = torch.cat(attn_output_list, dim=1) + attn_weights = torch.cat(attn_weights_list, dim=1) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, compressed_kvs, None + def __qeff_init__( self, ): @@ -444,15 +581,15 @@ def fused_forward( compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = torch.matmul(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: - #compressed_kv = 
compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) compressed_kv = compressed_kvs.write_only_ckv(compressed_kv, self.layer_idx, cache_kwargs) kva = compressed_kv @@ -466,56 +603,11 @@ def fused_forward( n_head_ckv = compressed_kv.shape[1] p = self.num_heads // n_head_ckv - ############################################################################ - - -# kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) -# v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) -# -# value_states=torch.matmul(kva_expanded, v_up_per_head) -# -# cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) -# q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) -# -# if compressed_kvs is not None: -# k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) -# -# k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) -# -# if enable_absorption: -# if absorb_online: -# print("online absorption") -# out = torch.matmul(self.per_head_q_up, self.per_head_k_up) -# q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) -# else: -# print("using fused qk") -# #breakpoint() -# q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) -# -# query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) -# key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) -# else: -# print("no absorption") -# q_nope = torch.bmm(q_a_proj_out, self.q_up) -# q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) -# query_states = torch.cat((q_nope, q_pe), dim=-1) -# -# k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) -# k_nope = torch.matmul(kva_expanded, k_up_per_head) -# key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) -# -# attn_weights = 
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - num_kv_blocks = 8 print("num_kv_blocks : ", num_kv_blocks) ctx_len = compressed_kv.shape[-2] block_size = -(-ctx_len // num_kv_blocks) - - - - kva_expanded = kva.transpose(1,0) #1,4,128,512 -> (4,1,128,512) v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) #64,512,128 @@ -527,34 +619,18 @@ def fused_forward( q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - #k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) #1,4,128,64 k_pe = compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) - #k_pe_expanded = k_pe.transpose(1,0) #reshape(-1, bsz, self.qk_nope_head_dim, self.qk_rope_head_dim) #4,1,128,64 - - - - attn_weights_list=[] for j in range(num_kv_blocks): kv_start_index = j * block_size kv_end_index = min(ctx_len, (j + 1) * block_size) - - #kva_expanded = kva.transpose(1,0) #1,4,128,512 -> (4,1,128,512) - #v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) #64,512,128 - #v_up_per_head = v_up_per_head.reshape(-1,p, self.kv_lora_rank, self.v_head_dim) #4,16,512,128 - #value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, -1, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 - - kva = compressed_kvs.read_only_blocked_ckv(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) kva_expanded = kva.transpose(1,0) k_pe = compressed_kvs.read_only_blocked_k_pe(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) k_pe_expanded = k_pe.transpose(1,0) - - - if enable_absorption: if absorb_online: print("online absorption") From 4f11f58a483a4bece26316db0acaa491e08a5024 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 8 Apr 2026 18:31:08 +0000 Subject: [PATCH 20/51] update compressed-tensors and tiktoken version Signed-off-by: Mamta Singh --- pyproject.toml | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e5060f5c4..4342e5cc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,8 @@ dependencies = [ "ftfy==6.3.1", "imageio==2.37.2", "imageio-ffmpeg==0.6.0", - "tiktoken", - "compressed-tensors", + "tiktoken==0.12.0", + "compressed-tensors==0.14.0", "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", From 4cbb3763405c8cfb6bf202870ad4cd64b3ad80e8 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Thu, 9 Apr 2026 07:09:43 +0000 Subject: [PATCH 21/51] add all forward for h and kv blocking Signed-off-by: Mamta Singh --- .../deepseek_v3/modeling_deepseek_qeff.py | 486 +++++++++--------- .../transformers/models/pytorch_transforms.py | 3 + 2 files changed, 241 insertions(+), 248 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index ed89f0fd2c..4a7338e6aa 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,5 +1,5 @@ import math -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union, Any import os import torch @@ -189,10 +189,57 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -''' +def update_running_softmax( + current_max: torch.Tensor, + attn_weights_block: torch.Tensor, + current_denominator: torch.Tensor, + output: torch.Tensor, + v_block: torch.Tensor, + skip_kv: bool = False, + skip_future: Optional[torch.Tensor] = None, +): + # Update Running row maximum + prev_max = 
current_max + current_max_updated = torch.max(prev_max, attn_weights_block.max(dim=3).values) + delta_max = prev_max - current_max_updated + + current_exp = torch.exp(attn_weights_block - current_max_updated.unsqueeze(-1)) + + # update running denominator + prev_denominator = current_denominator + curr_exp_sum = torch.einsum("bhqk->bhq", current_exp) + current_denominator_updated = prev_denominator * torch.exp(delta_max) + curr_exp_sum + + prob = current_exp / current_denominator_updated.unsqueeze(-1) + + prev_output = output + # if updating running softmax with attention sinks, we don't have v_block + if v_block is not None: + output_updated = ((prev_denominator / current_denominator_updated).unsqueeze(-1)) * prev_output * torch.exp( + delta_max.unsqueeze(-1) + ) + torch.matmul(prob, v_block) + else: + output_updated = ( + ((prev_denominator / current_denominator_updated).unsqueeze(-1)) + * prev_output + * torch.exp(delta_max.unsqueeze(-1)) + ) + + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + current_max = torch.where(skip_future, prev_max, current_max_updated) + current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) + output = torch.where(skip_future.unsqueeze(-1), prev_output, output_updated) + else: + # Eager mode + current_max = current_max_updated + current_denominator = current_denominator_updated + output = output_updated + return current_max, current_denominator, output + + + class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" - def __qeff_init__( self, ): @@ -232,7 +279,7 @@ def __qeff_init__( self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - def fused_forward( + def fused_forward_h_blocking( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], @@ -258,6 +305,7 @@ def fused_forward( q_pe = q_pe.view(bsz, q_len, self.num_heads, 
self.qk_rope_head_dim).transpose(1, 2) q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: @@ -272,9 +320,6 @@ def fused_forward( enable_absorption = False n_head_ckv = compressed_kv.shape[1] - p = self.num_heads//n_head_ckv - - cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -330,107 +375,7 @@ def fused_forward( return attn_output, attn_weights, compressed_kvs, None - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(bsz, q_len, -1, 576)[:, :, 0, :].reshape(bsz, q_len, 576) - - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + 
self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) - key_states = torch.cat((k_nope, k_pe_new), -1) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value, value_states - -''' -class QEffDeepseekV3Attention(nn.Module): - """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" - def __qeff_init___for_h_blocking( - self, - ): - q_up, q_rope = self.q_b_proj.weight.T.view( - -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim - ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - self.q_up = torch.nn.Parameter(q_up.detach().clone()) - - q_rope = q_rope.reshape(-1, self.num_heads * 
self.qk_rope_head_dim).unsqueeze(0) - self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) - - k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) - - self.k_up = torch.nn.Parameter(k_up.detach()) - self.v_up = torch.nn.Parameter(v_up.detach()) - per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - per_head_k_up = ( - self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) - ) - per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) - self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) - self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) - self.per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) - - fusedqk_list = [] - for i in range(self.num_heads): - fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) - fusedqk = torch.cat(fusedqk_list, dim=0) - fusedqk = fusedqk.reshape(1, self.num_heads, -1, self.kv_lora_rank) - - self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - - def fused_forward_h_blocking( + def fused_forward_blocked_kv( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], @@ -446,23 +391,24 @@ def fused_forward_h_blocking( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + num_kv_blocks = 4 + ctx_len = compressed_kvs.layers[0].ckv.shape[2] + kv_block_size = -(-ctx_len // num_kv_blocks) + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - 
compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.matmul(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = torch.matmul(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - + compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if compressed_kvs is not None: - compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) - - kva = compressed_kv + if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -470,96 +416,91 @@ def fused_forward_h_blocking( else: enable_absorption = False - n_head_ckv = compressed_kv.shape[1] - p = self.num_heads//n_head_ckv - - + ## Write Only + if compressed_kvs is not None: + compressed_kv = compressed_kvs.write_only_ckv(compressed_kv, self.layer_idx, cache_kwargs) - cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + k_pe = compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + if enable_absorption: + if absorb_online: + qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) + qkupTrope_nope = torch.cat((dq_qup_kupT, 
q_pe), dim=-1) + batch_size, num_heads, seq_len, _ = qkupTrope_nope.shape + else: + batch_size, num_heads, seq_len, _ = q_nope.shape - attn_output_list = [] - attn_weights_list = [] - for head_block_idx in range(self.num_heads//n_head_ckv): - h_start = head_block_idx * n_head_ckv - h_end = min(h_start+n_head_ckv, self.num_heads) - - if enable_absorption: - if absorb_online: - qup_kupT = torch.matmul(self.per_head_q_up[:, h_start:h_end,:,:], self.per_head_k_up[:, h_start:h_end,:,:]) - dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) - else: - dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end,:,:]) - qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) - krope_nope = torch.cat((kva, k_pe), dim=-1) - attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) + + current_position = position_ids.max(dim=-1).values + skip_kv=True + output = torch.zeros(batch_size, self.num_heads, seq_len, self.kv_lora_rank, device=hidden_states.device) + + current_max = torch.full( + (batch_size, num_heads, seq_len), + float(MIN_MASKED_ATTENTION_VALUE), + device=hidden_states.device, + ) + current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=hidden_states.device) - else: - knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) - krope_nope = torch.cat((knope, k_pe), dim=-1) - qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) - attn_weights = 
torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = ctx_len - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=hidden_states.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, self.layer_idx, cache_kwargs) + k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, self.layer_idx, cache_kwargs) + + + causal_mask_block = _create_causal_mask( + position_ids=position_ids, + target_length=end_index, + start_index=start_index, + ) + + if enable_absorption: + krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) + attn_weights_block = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] + attn_weights_block = torch.where( + causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block + ) + current_max, current_denominator, output = update_running_softmax(current_max, attn_weights_block, current_denominator, output, compressed_kv_block, skip_kv, skip_future) + # [1, 64, q_len, 
kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] + else: + knope = torch.matmul(compressed_kv_block, self.per_head_k_up_normal) + breakpoint() + krope_nope = torch.cat((knope, k_pe_block.expand(-1,64,-1,-1)), dim=-1) + qrope_nope = torch.cat((q_nope, q_pe), dim=-1) + attn_weights_block = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + attn_weights_block = torch.where( + causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block + ) + current_max, current_denominator, output = update_running_softmax(current_max, attn_weights_block, current_denominator, output, compressed_kv_block, skip_kv, skip_future) - attn_output = torch.cat(attn_output_list, dim=1) - attn_weights = torch.cat(attn_weights_list, dim=1) + attn_output = torch.matmul(output, self.per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, compressed_kvs, None - - def __qeff_init__( - self, - ): - q_up, q_rope = self.q_b_proj.weight.T.view( - -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim - ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) + return attn_output, None, compressed_kvs, None - self.q_up = torch.nn.Parameter(q_up.detach().clone()) - q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) - - self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) - k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) - - self.k_up = torch.nn.Parameter(k_up.detach().clone()) - self.v_up = 
torch.nn.Parameter(v_up.detach().clone()) - per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - per_head_k_up = ( - self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) - ) - self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone()) - self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone()) - - fusedqk = torch.bmm(per_head_q_up, per_head_k_up) - self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - - def fused_forward( + def fused_forward_basic( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], @@ -581,16 +522,15 @@ def fused_forward( compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = torch.bmm(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.matmul(q_a_proj_out, self.q_up) + q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if compressed_kvs is not None: - compressed_kv = compressed_kvs.write_only_ckv(compressed_kv, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) kva = compressed_kv @@ -603,64 +543,45 @@ def fused_forward( n_head_ckv = compressed_kv.shape[1] p = self.num_heads // n_head_ckv - num_kv_blocks = 8 - print("num_kv_blocks : ", num_kv_blocks) - ctx_len = compressed_kv.shape[-2] - block_size = -(-ctx_len // num_kv_blocks) + ############################################################################ - kva_expanded = kva.transpose(1,0) #1,4,128,512 -> (4,1,128,512) - 
v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) #64,512,128 - v_up_per_head = v_up_per_head.reshape(-1,p, self.kv_lora_rank, self.v_head_dim) #4,16,512,128 + kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) - value_states=torch.matmul(kva_expanded, v_up_per_head).reshape(bsz, self.num_heads, -1, self.v_head_dim) #4,16,128,128 -> 1,64,128,128 + value_states=torch.matmul(kva_expanded, v_up_per_head) - cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) - - attn_weights_list=[] - for j in range(num_kv_blocks): - kv_start_index = j * block_size - kv_end_index = min(ctx_len, (j + 1) * block_size) + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - kva = compressed_kvs.read_only_blocked_ckv(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) - kva_expanded = kva.transpose(1,0) - k_pe = compressed_kvs.read_only_blocked_k_pe(kv_start_index, kv_end_index, self.layer_idx, cache_kwargs) - k_pe_expanded = k_pe.transpose(1,0) + k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) - if enable_absorption: - if absorb_online: - print("online absorption") - out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) - else: - print("using fused qk") - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) #1,1,32,1536 , 64,1536,512 -> 1,64,32,512 - - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) #1,64,32,512, 1,64,32,64 -> 
1,64,32,576 -> - query_states = query_states.reshape(-1,p,q_len,self.kv_lora_rank + self.qk_rope_head_dim) #1,64,32,576 -> 4,16,32,576 - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) #4,1,128,512, 4,1,128,64 -> 4,1,128,576 + if enable_absorption: + if absorb_online: + print("online absorption") + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) else: - print("no absorption") - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - query_states = torch.cat((q_nope, q_pe), dim=-1) - - k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) - k_nope = torch.matmul(kva_expanded, k_up_per_head) - key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - attn_weights = attn_weights.reshape(bsz, self.num_heads, q_len, -1) + print("using fused qk") + #breakpoint() + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) - attn_weights_list.append(attn_weights) - - attn_weights = torch.cat(attn_weights_list, dim=-1) + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + else: + print("no absorption") + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) + k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just 
slice it attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights @@ -673,8 +594,70 @@ def fused_forward( attn_output = self.o_proj(attn_output) return attn_output, attn_weights, compressed_kvs, value_states + + def fused_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if os.environ.get("KIMI_BLOCKING", "0") == "h": + return self.fused_forward_h_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + elif os.environ.get("KIMI_BLOCKING", "0") == "kv": + return self.fused_forward_blocked_kv( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + else: + return self.fused_forward_basic( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + def forward( self, hidden_states: torch.Tensor, @@ -779,7 +762,7 @@ def moe( expert_output = torch.bmm(hidden, down_proj) experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) experts_out = 
experts_out * topk_weights.unsqueeze(-1) - # final_hidden_states = experts_out.sum(dim=1) + final_hidden_states = torch.einsum("abc->ac", experts_out) return final_hidden_states.type(hidden_states.dtype) @@ -898,10 +881,10 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + orig_hidden_states = self.input_layernorm(hidden_states) if enable_mla: hidden_states, self_attn_weights, present_compressed_kvs, vs = self.self_attn.fused_forward( - hidden_states=hidden_states, + hidden_states=orig_hidden_states, attention_mask=attention_mask, position_ids=position_ids, position_embeddings=position_embeddings, @@ -914,10 +897,9 @@ def forward( mla_absorption=mla_absorption, **kwargs, ) - else: hidden_states, self_attn_weights, present_key_value, vs = self.self_attn( - hidden_states=hidden_states, + hidden_states=orig_hidden_states, attention_mask=attention_mask, position_ids=position_ids, position_embeddings=position_embeddings, @@ -1072,7 +1054,15 @@ def forward( class QEffDeepseekV3ForCausalLM(nn.Module): """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations.""" - + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {self.model.layers[0].__class__} + def forward( self, input_ids: torch.LongTensor = None, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 39ab5010b4..cd802bd48c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1041,6 +1041,9 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "DeepseekV3Attention": { "forward": QEffDeepseekV3Attention.forward, "fused_forward": QEffDeepseekV3Attention.fused_forward, + "fused_forward_blocked_kv": QEffDeepseekV3Attention.fused_forward_blocked_kv, + "fused_forward_basic": QEffDeepseekV3Attention.fused_forward_basic, + "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, "DeepseekV3RMSNorm": { From 04fa702ad2ddb157e67dca7da82bf757ea4bd752 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Thu, 9 Apr 2026 07:44:43 +0000 Subject: [PATCH 22/51] removing blocking config changes Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 65 -- QEfficient/blocking/attention_blocking.py | 126 ---- .../blocking/blocked_attention_forwards.py | 576 ------------------ QEfficient/blocking/blocking_configurator.py | 259 -------- .../deepseek_v3/modeling_deepseek_qeff.py | 1 - .../transformers/models/modeling_auto.py | 58 +- .../transformers/models/pytorch_transforms.py | 55 +- QEfficient/utils/__init__.py | 2 - QEfficient/utils/_utils.py | 17 - QEfficient/utils/constants.py | 4 +- QEfficient/utils/export_utils.py | 1 - QEfficient/utils/hash_utils.py | 8 - examples/run_kimik2.py | 8 +- 13 files changed, 27 insertions(+), 1153 deletions(-) delete mode 100644 QEfficient/blocking/attention_blocking.py delete mode 100644 QEfficient/blocking/blocked_attention_forwards.py delete mode 100644 QEfficient/blocking/blocking_configurator.py diff --git 
a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index d524f65f33..2a7cb85e94 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -26,19 +26,15 @@ ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile -from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.transformers.models.pytorch_transforms import BlockingAttentionTransform from QEfficient.utils import ( constants, create_json, create_model_params, dump_qconfig, generate_mdp_partition_config, - get_attr_or_key, hash_dict_params, load_json, - require_value, ) from QEfficient.utils.export_utils import export_wrapper @@ -371,46 +367,6 @@ def get_onnx_path( return self.onnx_path - def transform( - self, - ctx_len: Optional[int] = None, - seq_len: Optional[int] = None, - bs: Optional[int] = 1, - num_devices: int = 1, - qaic_config: Optional[dict] = None, - **compiler_options, - ): - # Apply the transformations that are dependent on compilation parameters - - qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - - if getattr(self.model, "config", None) or getattr(self.model.model, "config", None): - blocking_config = build_transformer_blocking_config_for_transform( - getattr(self.model, "config", None) - if getattr(self.model, "config", None) - else getattr(self.model.model, "config", None), - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - num_devices=num_devices, - qaic_config=qaic_config, - **compiler_options, - ) - else: - # without a model config, this is not a model that is possible to block - blocking_config = None - - if blocking_config is not None: - self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) - blocking_kwargs = self.hash_params.setdefault("blocking_kwargs", 
{}) - if blocking_config.num_kv_blocks: - blocking_kwargs["num_kv_blocks"] = blocking_config.num_kv_blocks - if blocking_config.num_q_blocks: - blocking_kwargs["num_q_blocks"] = blocking_config.num_q_blocks - if blocking_config.head_block_size: - blocking_kwargs["head_block_size"] = blocking_config.head_block_size - - @dump_qconfig def _compile( self, @@ -429,10 +385,6 @@ def _compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, - disable_blocking: Optional[bool] = True, - blocking_mode: Optional[str] = "hqkv", - vtcm_ratio: Optional[float] = 0.75, - qaic_config: Optional[dict] = None, enable_mla: Optional[bool] = False, mla_absorption_config: Optional[Dict[str, bool]] = False, **compiler_options, @@ -461,23 +413,6 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ - # Transform before export - qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - bs = require_value(get_attr_or_key(specializations[0], ("batch_size", "batch")), "batch size") - seq_len = get_attr_or_key(specializations[0], ("cl", "seq_len", "sequence_length")) - ctx_len = get_attr_or_key(specializations[0], ("ctx_len", "context_length")) - self.transform( - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - num_devices=mdp_ts_num_devices, - disable_blocking=disable_blocking, - blocking_mode=blocking_mode, - vtcm_ratio=vtcm_ratio, - qaic_config=qaic_config, - **compiler_options, - ) - onnx_path = Path( onnx_path if onnx_path diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py deleted file mode 100644 index da530d6c8b..0000000000 --- a/QEfficient/blocking/attention_blocking.py +++ /dev/null @@ -1,126 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from __future__ import annotations - -from dataclasses import dataclass -from enum import Enum -from typing import Callable, Dict, Optional - -import torch -from transformers.cache_utils import Cache - -from QEfficient.blocking.blocked_attention_forwards import ( - blocked_h_attention_forward, - blocked_hqkv_attention_forward, - blocked_kv_attention_forward, - blocked_q_attention_forward, - blocked_qkv_attention_forward, - invalid_blocking_attention_forward, -) - - -class BlockingMode(str, Enum): - NONE = "" - KV = "kv" - Q = "q" - H = "h" - QKV = "qkv" - HQKV = "hqkv" - - -@dataclass -class AttentionBlockingConfig: - mode: BlockingMode = BlockingMode.NONE - num_kv_blocks: Optional[int] = None - num_q_blocks: Optional[int] = None - head_block_size: Optional[int] = None - - -def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: - return past_key_value is not None and hasattr(past_key_value, "read_only_blockedKV") - - -_STRATEGIES: Dict[BlockingMode, Callable] = { - BlockingMode.NONE: invalid_blocking_attention_forward, - BlockingMode.KV: blocked_kv_attention_forward, - BlockingMode.Q: blocked_q_attention_forward, - BlockingMode.H: blocked_h_attention_forward, - BlockingMode.QKV: blocked_qkv_attention_forward, - BlockingMode.HQKV: blocked_hqkv_attention_forward, -} - - -def get_blocking_strategy(config: AttentionBlockingConfig) -> Callable: - return _STRATEGIES.get(config.mode, _STRATEGIES[BlockingMode.NONE]) - - -def generic_blocked_attention_interface( - module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - layer_idx: int, - past_key_value: Cache, - blocking_config: AttentionBlockingConfig, - comp_ctx_lengths: Optional[torch.LongTensor] = None, - batch_index: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - 
past_seen_tokens: Optional[int] = None, - non_blocked_forward: Callable = None, - **kwargs, -): - use_kv_blocked = ( - blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value) - ) - use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) - - if past_key_value is not None: - if use_kv_blocked: - cache_kwargs = { - "batch_index": batch_index, - "position_ids": position_ids, - "past_seen_tokens": past_seen_tokens, - } - past_key_value.write_only(key, value, module.layer_idx, cache_kwargs) - else: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key, value = past_key_value.update(key, value, module.layer_idx, cache_kwargs) - - if use_blocking: - strategy = get_blocking_strategy(blocking_config) - attn_output, attn_weights = strategy( - module=module, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - scaling=scaling, - cache_kwargs=cache_kwargs, - layer_idx=layer_idx, - past_key_value=past_key_value, - num_kv_blocks=blocking_config.num_kv_blocks, - num_q_blocks=blocking_config.num_q_blocks, - head_block_size=blocking_config.head_block_size, - ) - else: - attn_output, attn_weights = non_blocked_forward( - module, - query, - key, - value, - attention_mask, - scaling=scaling, - **kwargs, - ) - - return attn_output, attn_weights diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py deleted file mode 100644 index 37ae942f52..0000000000 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ /dev/null @@ -1,576 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from __future__ import annotations - -import math -from typing import Any, Callable, Dict, Optional, Tuple - -import torch -from torch import nn -from transformers.cache_utils import Cache - -from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep) for GQA. - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def _get_kv_states(module: nn.Module, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - num_kv_groups = getattr(module, "num_key_value_groups", None) - if num_kv_groups is None: - return key, value - return repeat_kv(key, num_kv_groups), repeat_kv(value, num_kv_groups) - - -def _normalize_int(value: Optional[torch.Tensor | int]) -> int: - if isinstance(value, torch.Tensor): - return int(value.item()) - return int(value) if value is not None else 0 - - -def blocked_kv_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - num_kv_blocks: int, - cache_kwargs: Dict[str, Any], - layer_idx: int, - past_key_value: Cache, - *, - score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None, - use_causal_mask: bool = False, - sliding_window: Optional[int] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Initialize result tensor - output = torch.zeros_like(query) - - # 
Initialize Running Maximum and Denominator - batch_size, num_heads, seq_len, _ = query.shape - current_max = torch.full( - (batch_size, num_heads, seq_len), - float(MIN_MASKED_ATTENTION_VALUE), - device=query.device, - ) - current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device) - - past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens")) - total_seen_tokens = past_seen_tokens + query.shape[2] - if torch.onnx.is_in_onnx_export(): - attention_mask = None - use_causal_mask = True - position_ids = cache_kwargs.get("position_ids") - num_kv_blocks = _normalize_int(num_kv_blocks) - kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)] - masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device) - - current_position = position_ids.max(dim=-1).values - - for j in range(num_kv_blocks): - start_index = kv_block_positions[j] - if j == num_kv_blocks - 1: - kv_len_block = past_seen_tokens - start_index - else: - kv_len_block = kv_block_positions[j + 1] - start_index - end_index = start_index + kv_len_block - - skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() - - k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) - k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) - - attn_weights_block = torch.matmul(query, k_block_states.transpose(2, 3)) * scaling - if score_mod is not None: - attn_weights_block = score_mod(attn_weights_block, start_index, end_index) - - mask_block = None - if attention_mask is not None: - mask_block = attention_mask[..., start_index:end_index] - if mask_block.shape[-1] != attn_weights_block.shape[-1]: - mask_block = None - - if use_causal_mask or mask_block is None: - target_length = min(total_seen_tokens, end_index) - causal_mask_block = _create_causal_mask( - position_ids=position_ids, - target_length=target_length, - 
sliding_window=sliding_window, - start_index=start_index, - ) - if mask_block is None: - mask_block = causal_mask_block - else: - mask_block = mask_block.to(torch.bool) | causal_mask_block - - if mask_block is not None: - attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block) - - # Update Running row maximum - prev_max = current_max - current_max_updated = torch.max(prev_max, attn_weights_block.max(dim=-1).values) - delta_max = prev_max - current_max_updated - - current_exp = torch.exp(attn_weights_block - current_max_updated.unsqueeze(-1)) - - # update running denominator - prev_denominator = current_denominator - curr_exp_sum = torch.einsum("bhqk->bhq", current_exp) - current_denominator_updated = prev_denominator * torch.exp(delta_max) + curr_exp_sum - - prob = current_exp / current_denominator_updated.unsqueeze(-1) - - prev_output = output - output_updated = ((prev_denominator / current_denominator_updated).unsqueeze(-1)) * prev_output * torch.exp( - delta_max.unsqueeze(-1) - ) + torch.matmul(prob, v_block_states) - - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): - current_max = torch.where(skip_future, prev_max, current_max_updated) - current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) - output = torch.where(skip_future.unsqueeze(-1), prev_output, output_updated) - else: - # Eager mode - current_max = current_max_updated - current_denominator = current_denominator_updated - output = output_updated - - attn_output = output.transpose(1, 2).contiguous() - attn_weights = None - - return attn_output, attn_weights - - -def blocked_qkv_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - num_kv_blocks: int, - num_q_blocks: int, - cache_kwargs: Dict[str, Any], - layer_idx: int, - past_key_value: Cache, - *, - score_mod: Optional[Callable[[torch.Tensor, int, int], 
torch.Tensor]] = None, - use_causal_mask: bool = False, - sliding_window: Optional[int] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Initialize Running Maximum and Denominator - batch_size, num_heads, seq_len, DH = query.shape - - past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens")) - if torch.onnx.is_in_onnx_export(): - attention_mask = None - use_causal_mask = True - position_ids = cache_kwargs.get("position_ids") - num_kv_blocks = _normalize_int(num_kv_blocks) - num_q_blocks = max(1, _normalize_int(num_q_blocks)) - - q_block_positions = [(i * seq_len) // num_q_blocks for i in range(num_q_blocks)] - q_output_blocks = [] - q_attn_blocks = [] - - kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)] - - masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device) - - current_position = position_ids.max(dim=-1).values - - for q_block_idx in range(num_q_blocks): - q_start = q_block_positions[q_block_idx] - if q_block_idx == num_q_blocks - 1: - q_len_block = seq_len - q_start - else: - q_len_block = q_block_positions[q_block_idx + 1] - q_start - - q_block = query[:, :, q_start : q_start + q_len_block, :] - - current_max = torch.full( - (batch_size, num_heads, q_len_block), - float(MIN_MASKED_ATTENTION_VALUE), - device=query.device, - ) - current_denominator = torch.zeros(batch_size, num_heads, q_len_block, device=query.device) - output_blocks = torch.zeros((batch_size, num_heads, q_len_block, DH), device=query.device, dtype=query.dtype) - - for j in range(num_kv_blocks): - start_index = kv_block_positions[j] - if j == num_kv_blocks - 1: - kv_len_block = past_seen_tokens - start_index - else: - kv_len_block = kv_block_positions[j + 1] - start_index - end_index = start_index + kv_len_block - - skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() - - # Eager mode Only - if not torch.onnx.is_in_onnx_export() and 
not torch.jit.is_tracing(): - if skip_future.item(): - break - - k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) - k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) - - attn_weights_block = torch.matmul(q_block, k_block_states.transpose(2, 3)) * scaling - if score_mod is not None: - attn_weights_block = score_mod(attn_weights_block, start_index, end_index) - - mask_block = None - if attention_mask is not None: - mask_block = attention_mask[..., start_index:end_index] - if mask_block.shape[-1] != attn_weights_block.shape[-1]: - mask_block = None - - if use_causal_mask or mask_block is None: - # target_length = min(total_seen_tokens, end_index) - target_length = torch.where( - torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), - past_seen_tokens, - end_index, - ) - causal_mask_block = _create_causal_mask( - position_ids=position_ids, - target_length=target_length, - sliding_window=sliding_window, - start_index=start_index, - ) - if mask_block is None: - mask_block = causal_mask_block - else: - mask_block = mask_block.to(torch.bool) | causal_mask_block - - if mask_block is not None: - attn_mask_block = mask_block[:, :, q_start : q_start + q_len_block, :] - attn_weights_block = torch.where(attn_mask_block, masked_tensor, attn_weights_block) - - # Update Running row maximum - prev_max = current_max - current_max_updated = torch.max(prev_max, attn_weights_block.max(dim=3).values) - delta_max = prev_max - current_max_updated - - current_exp = torch.exp(attn_weights_block - current_max_updated.unsqueeze(-1)) - - # update running denominator - prev_denominator = current_denominator - curr_exp_sum = torch.einsum("bhqk->bhq", current_exp) - current_denominator_updated = prev_denominator * torch.exp(delta_max) + curr_exp_sum - - prob = current_exp / current_denominator_updated.unsqueeze(-1) - - prev_output = output_blocks - output_updated = ((prev_denominator 
/ current_denominator_updated).unsqueeze(-1)) * prev_output * torch.exp( - delta_max.unsqueeze(-1) - ) + torch.matmul(prob, v_block_states) - - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): - current_max = torch.where(skip_future, prev_max, current_max_updated) - current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) - output_blocks = torch.where(skip_future.unsqueeze(-1), prev_output, output_updated) - else: - # Eager mode - current_max = current_max_updated - current_denominator = current_denominator_updated - output_blocks = output_updated - q_output_blocks.append(output_blocks) - q_attn_blocks.append(attn_weights_block) - - attn_output = torch.cat(q_output_blocks, dim=2).transpose(1, 2).contiguous() - attn_weights = torch.cat(q_attn_blocks, dim=2) - - return attn_output, attn_weights - - -def blocked_hqkv_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - num_kv_blocks: int, - num_q_blocks: int, - head_block_size: int, - cache_kwargs: Dict[str, Any], - layer_idx: int, - past_key_value: Cache, - *, - score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None, - use_causal_mask: bool = False, - sliding_window: Optional[int] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Initialize Running Maximum and Denominator - batch_size, num_heads, seq_len, DH = query.shape - - past_seen_tokens = _normalize_int(cache_kwargs.get("past_seen_tokens")) - if torch.onnx.is_in_onnx_export(): - attention_mask = None - use_causal_mask = True - position_ids = cache_kwargs.get("position_ids") - num_kv_blocks = _normalize_int(num_kv_blocks) - if head_block_size <= 0: - head_block_size = num_heads - num_head_blocks = math.ceil(num_heads / head_block_size) - num_q_blocks = max(1, _normalize_int(num_q_blocks)) - - q_block_positions = [(i * seq_len) // num_q_blocks for i in 
range(num_q_blocks)] - - h_output_blocks = [] - h_attn_blocks = [] - - kv_block_positions = [(i * past_seen_tokens) // num_kv_blocks for i in range(num_kv_blocks)] - - masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device) - - current_position = position_ids.max(dim=-1).values - - # Process each head block independently - for head_block_idx in range(num_head_blocks): - h_start = head_block_idx * head_block_size - h_end = min(h_start + head_block_size, num_heads) - - # Extract head blocks - q_g = query[:, h_start:h_end, :, :] - - q_output_blocks = [] - q_attn_blocks = [] - - for q_block_idx in range(num_q_blocks): - q_start = q_block_positions[q_block_idx] - if q_block_idx == num_q_blocks - 1: - q_len_block = seq_len - q_start - else: - q_len_block = q_block_positions[q_block_idx + 1] - q_start - - q_block = q_g[:, :, q_start : q_start + q_len_block, :] - - current_max = torch.full( - (batch_size, h_end - h_start, q_len_block), - float(MIN_MASKED_ATTENTION_VALUE), - device=query.device, - ) - current_denominator = torch.zeros(batch_size, h_end - h_start, q_len_block, device=query.device) - output_blocks = torch.zeros( - (batch_size, h_end - h_start, q_len_block, DH), device=query.device, dtype=query.dtype - ) - - for j in range(num_kv_blocks): - start_index = kv_block_positions[j] - if j == num_kv_blocks - 1: - kv_len_block = past_seen_tokens - start_index - else: - kv_len_block = kv_block_positions[j + 1] - start_index - end_index = start_index + kv_len_block - - skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() - - # Eager mode Only - if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): - if skip_future.item(): - break - - k_block, v_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) - k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) - - k_g = k_block_states[:, h_start:h_end, :, :] - v_g = 
v_block_states[:, h_start:h_end, :, :] - - attn_weights_block = torch.matmul(q_block, k_g.transpose(2, 3)) * scaling - if score_mod is not None: - attn_weights_block = score_mod(attn_weights_block, start_index, end_index) - - mask_block = None - if attention_mask is not None: - mask_block = attention_mask[..., start_index:end_index] - if mask_block.shape[-1] != attn_weights_block.shape[-1]: - mask_block = None - - if use_causal_mask or mask_block is None: - # target_length = min(total_seen_tokens, end_index) - target_length = torch.where( - torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), - past_seen_tokens, - end_index, - ) - causal_mask_block = _create_causal_mask( - position_ids=position_ids, - target_length=target_length, - sliding_window=sliding_window, - start_index=start_index, - ) - if mask_block is None: - mask_block = causal_mask_block - else: - mask_block = mask_block.to(torch.bool) | causal_mask_block - - if mask_block is not None: - mask_block_g = mask_block[:, :, q_start : q_start + q_len_block, :] - attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block) - - # Update Running row maximum - prev_max = current_max - current_max_updated = torch.max(prev_max, attn_weights_block.max(dim=3).values) - delta_max = prev_max - current_max_updated - - current_exp = torch.exp(attn_weights_block - current_max_updated.unsqueeze(-1)) - - # update running denominator - prev_denominator = current_denominator - curr_exp_sum = torch.einsum("bhqk->bhq", current_exp) - current_denominator_updated = prev_denominator * torch.exp(delta_max) + curr_exp_sum - - prob = current_exp / current_denominator_updated.unsqueeze(-1) - - prev_output = output_blocks - output_updated = ( - (prev_denominator / current_denominator_updated).unsqueeze(-1) - ) * prev_output * torch.exp(delta_max.unsqueeze(-1)) + torch.matmul(prob, v_g) - - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): - # skip_mask = 
skip_future.view(1, 1, 1).expand(batch_size, h_end - h_start, q_len_block) - current_max = torch.where(skip_future, prev_max, current_max_updated) - current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) - output_blocks = torch.where(skip_future, prev_output, output_updated) - else: - # Eager mode - current_max = current_max_updated - current_denominator = current_denominator_updated - output_blocks = output_updated - q_output_blocks.append(output_blocks) - q_attn_blocks.append(attn_weights_block) - - head_output = torch.cat(q_output_blocks, dim=2) - head_attn_weights = torch.cat(q_attn_blocks, dim=2) - h_output_blocks.append(head_output) - h_attn_blocks.append(head_attn_weights) - - attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() - attn_weights = torch.cat(h_attn_blocks, dim=1) - - return attn_output, attn_weights - - -def blocked_h_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - head_block_size: int, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Q-blocked attention that slices the query sequence into blocks and processes each block. 
- """ - batch_size, num_heads, q_len, _ = query.shape - if head_block_size <= 0: - head_block_size = num_heads - num_head_blocks = math.ceil(num_heads / head_block_size) - - key_states, value_states = _get_kv_states(module, key, value) - - h_output_blocks = [] - h_attn_blocks = [] - - masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device) - - # Process each head block independently - for head_block_idx in range(num_head_blocks): - h_start = head_block_idx * head_block_size - h_end = min(h_start + head_block_size, num_heads) - - # Extract head blocks - q_g = query[:, h_start:h_end, :, :] - k_g = key_states[:, h_start:h_end, :, :] - v_g = value_states[:, h_start:h_end, :, :] - - attn_weights = torch.matmul(q_g, k_g.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where(attention_mask, masked_tensor, attn_weights) - - attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - output_block = torch.matmul(attn_weights, v_g) - - h_output_blocks.append(output_block) - h_attn_blocks.append(attn_weights) - - attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() - attn_weights = torch.cat(h_attn_blocks, dim=1) - - return attn_output, attn_weights - - -def blocked_q_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - num_q_blocks: int, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Q-blocked attention that slices the query sequence into blocks and processes each block. 
- """ - batch_size, num_heads, q_len, _ = query.shape - num_q_blocks = max(1, _normalize_int(num_q_blocks)) - key_states, value_states = _get_kv_states(module, key, value) - - q_block_positions = [(i * q_len) // num_q_blocks for i in range(num_q_blocks)] - q_output_blocks = [] - q_attn_blocks = [] - - masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=query.device) - - for q_block_idx in range(num_q_blocks): - q_start = q_block_positions[q_block_idx] - if q_block_idx == num_q_blocks - 1: - q_len_block = q_len - q_start - else: - q_len_block = q_block_positions[q_block_idx + 1] - q_start - - q_block = query[:, :, q_start : q_start + q_len_block, :] - attn_mask_block = None - if attention_mask is not None: - attn_mask_block = attention_mask[:, :, q_start : q_start + q_len_block, :] - - attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling - if attn_mask_block is not None: - attn_weights = torch.where(attn_mask_block, masked_tensor, attn_weights) - - attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - output_block = torch.matmul(attn_weights, value_states) - - q_output_blocks.append(output_block) - q_attn_blocks.append(attn_weights) - - attn_output = torch.cat(q_output_blocks, dim=2).transpose(1, 2).contiguous() - attn_weights = torch.cat(q_attn_blocks, dim=2) - - return attn_output, attn_weights - - -def invalid_blocking_attention_forward(*args, **kwargs): - raise NotImplementedError("Invalid blocking strategy was selected") diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py deleted file mode 100644 index a00a2bdc31..0000000000 --- a/QEfficient/blocking/blocking_configurator.py +++ /dev/null @@ -1,259 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- -""" -Utility helpers to suggest attention/FFN blocking configs for diffusers transformers and transformers - -This module adapts the standalone configurator script into a clean, importable API -that can be fed model config + pipeline compile config to derive blocking settings. -""" - -from __future__ import annotations - -import math -from typing import Any, Dict, List, Optional - -from QEfficient.blocking.attention_blocking import AttentionBlockingConfig, BlockingMode -from QEfficient.utils import get_attr_or_key, require_value -from QEfficient.utils.constants import VTCM_SIZE_THRESHOLD - - -def _infer_head_dim(model_config: Any, num_heads: int) -> int: - head_dim = get_attr_or_key(model_config, ("attention_head_dim", "head_dim", "head_dim_per_head")) - if head_dim is not None: - return int(head_dim) - hidden_size = get_attr_or_key(model_config, ("hidden_size", "d_model", "model_dim", "attention_dim")) - if hidden_size is None: - raise ValueError("Missing head_dim or hidden_size to compute attention blocking configuration.") - return int(hidden_size) // int(num_heads) - - -def _infer_data_bytes(compile_config: Dict[str, Any]) -> int: - explicit = compile_config.get("data_bytes") - if explicit is not None: - return int(explicit) - if compile_config.get("convert_to_fp16", False): - return 2 - return 4 - - -def _normalize_attention_mode(raw_mode: str) -> str: - mode = raw_mode.lower() - if "q" in mode and "kv" in mode: - return "qkv" - if "kv" in mode: - return "kv" - if "q" in mode: - return "q" - return "" - - -def _resolve_effective_blocking_mode(attention_cfg: Dict[str, Any], requested_mode: str) -> str: - mode = _normalize_attention_mode(requested_mode) - if mode == "": - return "" - num_q_blocks = attention_cfg.get("num_q_blocks") or 1 - num_kv_blocks = attention_cfg.get("num_kv_blocks") or 1 - if num_q_blocks > 1 and num_kv_blocks > 1: 
- return "qkv" - if num_q_blocks > 1: - return "q" - if num_kv_blocks > 1: - return "kv" - return "" - - -def _get_valid_num_blocks(config: Dict, requested_key: str) -> int: - if config.get(requested_key) < 1: - raise ValueError(f"Invalid value {requested_key} passed in qaic_config: {config.get(requested_key)}") - return config.get(requested_key) - - -def block_candidates_generator(max_length: int) -> List[int]: - block_list = [] - i = 1 - step = 1 - while i <= max_length: - block_list.append(i) - if i % (4 * step) == 0: - step *= 2 - i += step - return block_list - - -def attention_configurator( - bs: int, - seq_len: int, - ctx_len: int, - num_heads: int, - head_dim: int, - num_socs: int, - num_nsps: int, - data_bytes: int, - blocking_mode: Optional[str] = None, -) -> Dict[str, Any]: - """ - Suggest attention blocking configuration based on model and device constraints. - """ - mode = (blocking_mode or "hqkv").lower() - - num_kv_blocks_list = block_candidates_generator(ctx_len) if "kv" in mode else [1] - num_q_blocks_list = block_candidates_generator(ctx_len) if "q" in mode else [1] - - head_block_size = num_socs if "h" in mode else num_heads - num_head_blocks = math.ceil(num_heads / head_block_size) - num_heads_per_iter = math.ceil(head_block_size / num_socs) - - best_config = { - "head_block_size": head_block_size, - "num_head_blocks": num_head_blocks, - "head_blocking_enabled": num_head_blocks > 1, - "num_q_blocks": None, - "num_kv_blocks": None, - "q_kv_ratio": None, - "vtcm_footprint": None, - } - - def update_best_config(num_q_blocks: int, num_kv_blocks: int, q_kv_ratio: float, footprint: float) -> None: - best_config["num_q_blocks"] = num_q_blocks - best_config["num_kv_blocks"] = num_kv_blocks - best_config["q_kv_ratio"] = q_kv_ratio - best_config["vtcm_footprint"] = footprint - - for num_q_blocks in num_q_blocks_list: - for num_kv_blocks in num_kv_blocks_list: - q_sl_per_nsp = math.ceil(seq_len / num_nsps / num_q_blocks) - q_size_per_nsp = 
num_heads_per_iter * bs * q_sl_per_nsp * head_dim * data_bytes - - kv_cl_per_nsp = math.ceil(ctx_len / num_kv_blocks) - kv_size_per_nsp = num_heads_per_iter * bs * kv_cl_per_nsp * head_dim * data_bytes - - qk_size_per_nsp = num_heads_per_iter * bs * q_sl_per_nsp * kv_cl_per_nsp * data_bytes - vtcm_footprint = q_size_per_nsp + kv_size_per_nsp + qk_size_per_nsp - q_kv_ratio = max(q_size_per_nsp / kv_size_per_nsp, kv_size_per_nsp / q_size_per_nsp) - num_total_blocks = num_q_blocks * num_kv_blocks - - if vtcm_footprint < VTCM_SIZE_THRESHOLD: - if best_config["num_q_blocks"] is None: - update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint) - elif best_config["num_q_blocks"] * best_config["num_kv_blocks"] > num_total_blocks: - update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint) - elif ( - best_config["num_q_blocks"] * best_config["num_kv_blocks"] == num_total_blocks - and best_config["q_kv_ratio"] >= q_kv_ratio - ): - update_best_config(num_q_blocks, num_kv_blocks, q_kv_ratio, vtcm_footprint) - break - - return best_config - - -def build_transformer_blocking_config( - model_config: Any, - pipeline_config: Optional[Any] = None, - module_name: str = "transformer", - blocking_mode: Optional[str] = None, - ctx_len: Optional[int] = None, - seq_len: Optional[int] = None, - bs: Optional[int] = 1, - compile_config: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """ - Build blocking configuration based on model config + pipeline compile config. 
- """ - if ctx_len is None: - ctx_len = seq_len - - if seq_len is None and ctx_len is None: - return AttentionBlockingConfig(mode="") - - num_heads = require_value( - get_attr_or_key(model_config, ("num_attention_heads", "num_heads", "attention_heads", "n_heads")), - "num attention heads", - ) - head_dim = _infer_head_dim(model_config, int(num_heads)) - - num_socs = int(compile_config.get("mdp_ts_num_devices", 1)) - num_nsps = int(compile_config.get("aic_num_cores", 1)) - data_bytes = _infer_data_bytes(compile_config) - - attention_cfg = attention_configurator( - int(bs), - int(seq_len), - int(ctx_len), - int(num_heads), - int(head_dim), - int(num_socs), - int(num_nsps), - int(data_bytes), - blocking_mode=blocking_mode, - ) - - resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv") - effective_mode = _resolve_effective_blocking_mode(attention_cfg, resolved_mode) - - return AttentionBlockingConfig( - mode=effective_mode, - num_kv_blocks=attention_cfg["num_kv_blocks"], - num_q_blocks=attention_cfg["num_q_blocks"], - head_block_size=attention_cfg["head_block_size"], - ) - - -def build_transformer_blocking_config_for_transform( - model_config: Any, - ctx_len: Optional[int] = None, - seq_len: Optional[int] = None, - bs: Optional[int] = 1, - num_devices: Optional[int] = 1, - qaic_config: Optional[dict] = None, - **compile_options, -) -> Dict[str, Any]: - - if qaic_config: - blocking_mode = BlockingMode(qaic_config.get("blocking_mode", "hqkv")) - else: - blocking_mode = BlockingMode.HQKV - enable_blocking = False if not qaic_config else qaic_config.get("enable_blocking", False) - - if qaic_config is None and enable_blocking: - blocking_config = build_transformer_blocking_config( - model_config, - blocking_mode=blocking_mode, - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - compile_config={"mdp_ts_num_devices": num_devices, **compile_options}, - ) - elif not enable_blocking: - blocking_config = None - else: - blocking_config = AttentionBlockingConfig() - 
mode_from_config = "" - if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode: - mode_from_config = "kv" + mode_from_config - blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") - if qaic_config.get("num_q_blocks", False) and enable_blocking and "q" in blocking_mode: - mode_from_config = "q" + mode_from_config - blocking_config.num_q_blocks = _get_valid_num_blocks(qaic_config, "num_q_blocks") - if qaic_config.get("head_block_size", False) and enable_blocking and "h" in blocking_mode: - mode_from_config = "h" + mode_from_config - blocking_config.head_block_size = _get_valid_num_blocks(qaic_config, "head_block_size") - - # check if qaic config did not provide any blocking details - if mode_from_config == "": - blocking_config = build_transformer_blocking_config( - model_config, - blocking_mode=blocking_mode, - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - compile_config={"mdp_ts_num_devices": num_devices, **compile_options}, - ) - else: - blocking_config.mode = BlockingMode(mode_from_config) - - return blocking_config diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 4a7338e6aa..dcdd6f6246 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -8,7 +8,6 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from QEfficient.blocking.attention_blocking import AttentionBlockingConfig, get_blocking_strategy, supports_blocked_kv from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE diff --git 
a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 939a883a42..f48d49cd2f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -46,6 +46,7 @@ _configure_proxy_for_model, ) from QEfficient.transformers.models.pytorch_transforms import ( + BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, @@ -1369,33 +1370,6 @@ def export( ) return self.onnx_path - def transform( - self, - ctx_len: Optional[int] = None, - seq_len: Optional[int] = None, - bs: Optional[int] = 1, - num_devices: int = 1, - qaic_config: Optional[dict] = None, - **compiler_options, - ): - self.vision_model.transform( - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - num_devices=num_devices, - qaic_config=qaic_config, - **compiler_options, - ) - - self.lang_model.transform( - ctx_len=ctx_len, - seq_len=seq_len, - bs=bs, - num_devices=num_devices, - qaic_config=qaic_config, - **compiler_options, - ) - def compile( self, img_size: Optional[int] = None, @@ -2014,6 +1988,9 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) + @classmethod def from_pretrained( cls, @@ -2818,14 +2795,8 @@ def __init__( self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config - self.model.pretrained_path = kwargs.pop("pretrained_model_name_or_path", None) - # self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) - # self.is_tlm = transformed - self.is_tlm = ( - (qaic_config is not None) - and (qaic_config.get("speculative_model_type") is not None) - and 
(model.__class__ in SpDTransform._module_mapping) - ) + self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) + self.is_tlm = transformed self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.ccl_enabled = False @@ -2843,9 +2814,14 @@ def __init__( # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the # previous transform function. - # self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) + self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) # TODO : Update in qaic_config isn't updated in the hash due to SpDTransforms. Need to move # SpDTransforms to PytorchTransforms. + if self.is_tlm: + self.model.qaic_config["return_pdfs"] = True + + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -3030,16 +3006,6 @@ def export( bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - # increase seq_len if using a larger number of blocks - if self.hash_params.get("blocking_kwargs", None): - max_blocks = -1 - for num_blocks in self.hash_params.get("blocking_kwargs").values(): - max_blocks = max(max_blocks, num_blocks) - block_size = -(-seq_len // max_blocks) - while seq_len < max_blocks or (seq_len % max_blocks > block_size): - seq_len = seq_len * 2 - block_size = -(-seq_len // max_blocks) - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS # kv_cache_shape = get_padding_shape_from_config( # self.model.config, fbs if self.continuous_batching else bs, seq_len diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 
cd802bd48c..4a15273595 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1110,50 +1110,21 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu return model, transformed -def get_decoder_layer_classes_for_export(model: nn.Module) -> set: - """ - Dynamically determine which DecoderLayer classes should be exported as functions - based on the model's architecture using the existing KVCacheTransform mapping. - """ - # Define patterns that identify decoder layer classes - DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] - - # Get all QEff classes that are decoder layers from the existing mapping - decoder_layer_classes = set() - - for original_class, qeff_class in KVCacheTransform._module_mapping.items(): - # Check if the QEff class name contains decoder layer patterns - qeff_class_name = qeff_class.__name__ - if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): - decoder_layer_classes.add(qeff_class) - - # Filter to only include classes that are actually used in the current model - model_decoder_classes = set() - for module in model.modules(): - if module.__class__ in decoder_layer_classes: - model_decoder_classes.add(module.__class__) - - return model_decoder_classes - - -class BlockingAttentionTransform: - _skip_classes = {} +class BlockedKVAttentionTransform: + _module_mapping = { + QEffLlamaAttention, + QEffQwen2_5_VLAttention, + } @classmethod - def apply(cls, model: nn.Module, attn_blocking_config) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]: transformed = False - supported_attention_classes = { - qeff_class - for qeff_class in KVCacheTransform._module_mapping.values() - if qeff_class.__name__.endswith("Attention") - } for module in model.modules(): - if type(module) in cls._skip_classes: - warnings.warn(f"Blocking is not yet supported for {type(module)}.") - 
continue - if type(module) in supported_attention_classes or model.config.model_type == "kimi_k2": - module.attn_blocking_config = attn_blocking_config - transformed = True - elif module.__class__.__name__.endswith("Attention") and type(module) not in supported_attention_classes: - warnings.warn(f"Blocking is not yet supported for {type(module)}.") + if type(module) in cls._module_mapping: + repl_module = type(module) + module.__class__ = repl_module + module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module) + transformed = True # Set to True if at least one transformation occurs + elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping: + warnings.warn(f"KV blocking is not yet supported for {type(module)}.") return model, transformed diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index d25198f9d6..3d6583f857 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -18,7 +18,6 @@ dump_qconfig, generate_mdp_partition_config, get_num_layers_from_config, - get_attr_or_key, get_num_layers_vlm, get_onnx_dir_name, get_padding_shape_from_config, @@ -35,7 +34,6 @@ onnx_exists, padding_check_and_fix, qpc_exists, - require_value, ) from QEfficient.utils.hash_utils import ( # noqa: F401 create_export_hash, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index e1b88dfab2..26bae7a34b 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -837,20 +837,3 @@ def custom_format_warning(msg, category, *args, **kwargs): YELLOW = "\033[93m" RESET = "\033[0m" return f"{YELLOW}[Warning]: {msg}{RESET}\n" - - -def get_attr_or_key(obj: Any, names: Tuple[str, ...], default: Any = None) -> Any: - if obj is None: - return default - for name in names: - if isinstance(obj, dict) and name in obj: - return obj[name] - if hasattr(obj, name): - return getattr(obj, name) - return default - - -def require_value(value: Any, label: str) 
-> Any: - if value is None: - raise ValueError(f"Missing required {label} to compute blocking configuration.") - return value diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index cc0b87b604..b3782605e1 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -33,9 +33,6 @@ ), } -# Blocking defaults -VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 - # Compiler defaults DEFAULT_AIC_NUM_CORES = 16 DEFAULT_AIC_MXPF6_MATMUL = False @@ -213,6 +210,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 + NUM_KV_BLOCKS = 8 MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS SAMPLER_OPS = { "repetition_penalties", diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 4501b1a932..da3231190e 100644 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -131,7 +131,6 @@ def _generate_export_hash(qeff_model, args, kwargs, func): output_names=all_args.get("output_names"), dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), - blocking_kwargs=all_args.get("blocking_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), ) diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 4cb137895e..10e6686d0c 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -7,7 +7,6 @@ import hashlib import json -from dataclasses import asdict, is_dataclass from typing import Dict from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN @@ -17,9 +16,6 @@ def json_serializable(obj): if isinstance(obj, set): # Convert set to a sorted list of strings for consistent hashing return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]) - if is_dataclass(obj): - # Convert dataclass to dict 
for serialization - return asdict(obj) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -62,10 +58,6 @@ def create_export_hash(**kwargs): export_params["dynamic_axes"] = kwargs.get("dynamic_axes") export_hash_params["export_params"] = export_params - blocking_kwargs = export_hash_params.pop("blocking_kwargs", None) - if blocking_kwargs: - export_hash_params.update(blocking_kwargs) - export_kwargs = kwargs.get("export_kwargs") if export_kwargs: export_hash_params.update(export_kwargs) diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index b50dc4ed2f..735b3ace32 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -90,11 +90,6 @@ print("Completion:", repr(predicted_string)) - - - -qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} - prefill_seq_len = 1 ctx_len = 1024 @@ -103,12 +98,11 @@ ctx_len=ctx_len, enable_mla=enable_mla, mla_absorption_config=mla_absorption_config, - mxfp6_matmul=False, + mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=num_kv_heads_repeat, num_cores=16, #prefill_only=True, - qaic_config=qaic_config, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 900ee0b2402df96b096c7a458d351052a41bae61 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 9 Apr 2026 11:54:03 +0000 Subject: [PATCH 23/51] fixed basic forward Signed-off-by: Onkar Chougule --- .../deepseek_v3/modeling_deepseek_qeff.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index dcdd6f6246..4dd00a389e 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -484,8 +484,7 @@ def fused_forward_blocked_kv( # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: 
knope = torch.matmul(compressed_kv_block, self.per_head_k_up_normal) - breakpoint() - krope_nope = torch.cat((knope, k_pe_block.expand(-1,64,-1,-1)), dim=-1) + krope_nope = torch.cat((knope, k_pe_block.expand(-1,self.num_heads,-1,-1)), dim=-1) qrope_nope = torch.cat((q_nope, q_pe), dim=-1) attn_weights_block = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale attn_weights_block = torch.where( @@ -539,37 +538,24 @@ def fused_forward_basic( else: enable_absorption = False - n_head_ckv = compressed_kv.shape[1] - p = self.num_heads // n_head_ckv - - ############################################################################ - - - kva_expanded = kva.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.kv_lora_rank) - v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1,0,2) - - value_states=torch.matmul(kva_expanded, v_up_per_head) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - k_pe_expanded = k_pe.unsqueeze(2).expand(-1,-1,p,-1,-1).reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) if enable_absorption: if absorb_online: print("online absorption") out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) + q_nope_compressed = torch.matmul(q_a_proj_out, out) else: print("using fused qk") - #breakpoint() - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk) + q_nope_compressed = torch.matmul(q_a_proj_out, self.fusedqk) query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + key_states = torch.cat((kva, k_pe), dim=-1) else: print("no absorption") q_nope = 
torch.bmm(q_a_proj_out, self.q_up) @@ -577,8 +563,8 @@ def fused_forward_basic( query_states = torch.cat((q_nope, q_pe), dim=-1) k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) - k_nope = torch.matmul(kva_expanded, k_up_per_head) - key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + k_nope = torch.matmul(kva, k_up_per_head) + key_states = torch.cat((k_nope, k_pe.expand(-1,self.num_heads,-1,-1)), dim=-1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it @@ -587,12 +573,13 @@ def fused_forward_basic( ) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, compressed_kvs, value_states + return attn_output, attn_weights, compressed_kvs, None From f763671e62d0bf53244a4f06f293494d91a52e9d Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 9 Apr 2026 12:42:44 +0000 Subject: [PATCH 24/51] fixed for kimi k2 Signed-off-by: Onkar Chougule --- .../transformers/models/modeling_auto.py | 9 ++++--- examples/export_kimik2.py | 26 +++++++------------ 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f48d49cd2f..c814fe1b53 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3152,14 +3152,17 @@ def export( qaic_config=self.model.qaic_config, ) if enable_mla: - mdp_ts_num_devices = kwargs.get("mdp_ts_num_devices", 1) + for lay in 
self.model.model.layers: + if lay is not None: + num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0]//576 + example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - ckv = torch.zeros((bs, mdp_ts_num_devices, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros((bs, mdp_ts_num_devices, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + ckv = torch.zeros((bs, num_heads, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) + k_pe = torch.zeros((bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) example_inputs["compressed_kvs"][i].append(ckv) example_inputs["compressed_kvs"][i].append(k_pe) dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py index b9c461281d..66a702bca2 100644 --- a/examples/export_kimik2.py +++ b/examples/export_kimik2.py @@ -4,36 +4,30 @@ from QEfficient import QEFFAutoModelForCausalLM #parameters to be configured -TS=4 + +prompt = "Once upon a time," +num_kv_heads_repeat=1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number shoul be equal to TS in that case. 
num_hidden_layers=2 +TS=4 enable_mla=True -mla_absorption_config={"enable": True, "online": False} -prefill_seq_len = 1 -ctx_len = 2048 -qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} - +mla_absorption_config={"enable": True, "online": True} model = AutoModelForCausalLM.from_pretrained( - "moonshotai/Kimi-K2-Thinking", - torch_dtype=torch.float32, - num_hidden_layers=num_hidden_layers, - trust_remote_code=True, + "moonshotai/Kimi-K2-Thinking", torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat = TS) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat = num_kv_heads_repeat) qpc_path = qeff_model.compile( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, + prefill_seq_len=1, + ctx_len=16*1024, enable_mla=enable_mla, mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=TS, - num_cores=16, - #prefill_only=True, - qaic_config=qaic_config, + num_cores=16 ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From dc204951241b7e455e9f56ca781b12144ffcbd98 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 10 Apr 2026 09:36:50 +0000 Subject: [PATCH 25/51] fix lint format Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 1 - QEfficient/transformers/cache_utils.py | 10 +- .../models/deepseek_v3/modeling_deepseek.py | 1 - .../deepseek_v3/modeling_deepseek_qeff.py | 200 ++++++++++-------- .../transformers/models/modeling_auto.py | 8 +- .../transformers/models/pytorch_transforms.py | 20 +- examples/export_kimik2.py | 23 +- .../causallm/example_pytorch_transforms.py | 12 +- examples/run_kimik2.py | 22 +- 9 files changed, 156 insertions(+), 141 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2a7cb85e94..0c6767cf1f 
100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -366,7 +366,6 @@ def get_onnx_path( self.export(**kwargs) return self.onnx_path - @dump_qconfig def _compile( self, diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 8e478528f3..6ac41242e7 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -350,7 +350,6 @@ def __init__(self, ckv, k_pe): def update_ckv(self, compressed_kv, cache_kwargs): position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later self.ckv = CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv) @@ -371,7 +370,6 @@ def update_ckv(self, compressed_kv, cache_kwargs): def update_k_pe(self, k_pe_cache, cache_kwargs): position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) @@ -394,7 +392,6 @@ def read_only_blocked_ckv(self, start_index, end_index, cache_kwargs): # Gather ckv_out = self.ckv position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) batch, num_kv_heads, _, _ = ckv_out.shape ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) @@ -417,7 +414,6 @@ def read_only_blocked_k_pe(self, start_index, end_index, cache_kwargs): # Gather k_pe_out = self.k_pe position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) batch, num_kv_heads, _, _ = k_pe_out.shape ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) @@ -438,14 +434,12 @@ def read_only_blocked_k_pe(self, start_index, end_index, cache_kwargs): def write_only_k_pe(self, k_pe_cache, cache_kwargs): position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) return self.k_pe def write_only_ckv(self, compressed_kv, cache_kwargs): position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later self.ckv = CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv) return self.ckv @@ -473,11 +467,11 @@ def read_only_blocked_k_pe(self, start_index, end_index, layer_idx, cache_kwargs return self.layers[layer_idx].read_only_blocked_k_pe(start_index, end_index, cache_kwargs) def write_only_ckv(self, ckv, layer_idx, cache_kwargs): - #self.append_new_layers(layer_idx) + # self.append_new_layers(layer_idx) return self.layers[layer_idx].write_only_ckv(ckv, cache_kwargs) def write_only_k_pe(self, k_pe, layer_idx, cache_kwargs): - #self.append_new_layers(layer_idx) + # self.append_new_layers(layer_idx) return self.layers[layer_idx].write_only_k_pe(k_pe, cache_kwargs) @classmethod diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 5eff081888..261f3bf752 100755 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -463,7 +463,6 @@ def forward(self, hidden_states): orig_shape = hidden_states.shape topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - flat_topk_idx = topk_idx.view(-1) if not self.training: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) if self.config.n_shared_experts is not None: 
diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 4dd00a389e..f3659a60f1 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,7 +1,7 @@ import math -from typing import Dict, List, Optional, Tuple, Type, Union, Any - import os +from typing import Dict, List, Optional, Tuple, Type, Union + import torch import torch.nn.functional as F from torch import nn @@ -236,22 +236,22 @@ def update_running_softmax( return current_max, current_denominator, output - class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" + def __qeff_init__( self, ): q_up, q_rope = self.q_b_proj.weight.T.view( -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - + q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) self.q_up = torch.nn.Parameter(q_up.detach().clone()) - + q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) - + k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) @@ -272,7 +272,7 @@ def __qeff_init__( fusedqk_list = [] for i in range(self.num_heads): - fusedqk_list.append(torch.matmul(per_head_q_up[i,:,:], per_head_k_up[i,:,:])) + fusedqk_list.append(torch.matmul(per_head_q_up[i, :, :], per_head_k_up[i, :, :])) fusedqk = torch.cat(fusedqk_list, dim=0) fusedqk = fusedqk.reshape(1, self.num_heads, -1, self.kv_lora_rank) @@ -296,7 +296,7 @@ def fused_forward_h_blocking( bsz, q_len, _ = hidden_states.size() compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = 
compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) @@ -304,7 +304,7 @@ def fused_forward_h_blocking( q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - + compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: @@ -326,46 +326,47 @@ def fused_forward_h_blocking( if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - attn_output_list = [] attn_weights_list = [] - for head_block_idx in range(self.num_heads//n_head_ckv): - h_start = head_block_idx * n_head_ckv - h_end = min(h_start+n_head_ckv, self.num_heads) - - if enable_absorption: - if absorb_online: - qup_kupT = torch.matmul(self.per_head_q_up[:, h_start:h_end,:,:], self.per_head_k_up[:, h_start:h_end,:,:]) - dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) - else: - dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end,:,:]) - qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) - krope_nope = torch.cat((kva, k_pe), dim=-1) - attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = 
torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) - - else: - knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) - krope_nope = torch.cat((knope, k_pe), dim=-1) - qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) - attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) + for head_block_idx in range(self.num_heads // n_head_ckv): + h_start = head_block_idx * n_head_ckv + h_end = min(h_start + n_head_ckv, self.num_heads) + + if enable_absorption: + if absorb_online: + qup_kupT = torch.matmul( + self.per_head_q_up[:, h_start:h_end, :, :], self.per_head_k_up[:, h_start:h_end, :, :] + ) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end, :, :]) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) + krope_nope = torch.cat((kva, k_pe), dim=-1) + attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, 
h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + else: + knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) + krope_nope = torch.cat((knope, k_pe), dim=-1) + qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) + attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) attn_output = torch.cat(attn_output_list, dim=1) attn_weights = torch.cat(attn_weights_list, dim=1) @@ -393,8 +394,8 @@ def fused_forward_blocked_kv( num_kv_blocks = 4 ctx_len = compressed_kvs.layers[0].ckv.shape[2] - kv_block_size = -(-ctx_len // num_kv_blocks) - + kv_block_size = -(-ctx_len // num_kv_blocks) + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -407,7 +408,6 @@ def fused_forward_blocked_kv( compressed_kv = self.kv_a_layernorm(compressed_kv) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -436,11 +436,10 @@ def fused_forward_blocked_kv( else: batch_size, num_heads, seq_len, _ = q_nope.shape - current_position = position_ids.max(dim=-1).values - skip_kv=True + skip_kv = True output = torch.zeros(batch_size, self.num_heads, seq_len, self.kv_lora_rank, 
device=hidden_states.device) - + current_max = torch.full( (batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE), @@ -463,34 +462,53 @@ def fused_forward_blocked_kv( if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): if skip_future.item(): break - - compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, self.layer_idx, cache_kwargs) + + compressed_kv_block = compressed_kvs.read_only_blocked_ckv( + start_index, end_index, self.layer_idx, cache_kwargs + ) k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, self.layer_idx, cache_kwargs) - - + causal_mask_block = _create_causal_mask( position_ids=position_ids, target_length=end_index, start_index=start_index, ) - + if enable_absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) - attn_weights_block = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] + attn_weights_block = ( + torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale + ) # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] attn_weights_block = torch.where( causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block ) - current_max, current_denominator, output = update_running_softmax(current_max, attn_weights_block, current_denominator, output, compressed_kv_block, skip_kv, skip_future) + current_max, current_denominator, output = update_running_softmax( + current_max, + attn_weights_block, + current_denominator, + output, + compressed_kv_block, + skip_kv, + skip_future, + ) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: knope = torch.matmul(compressed_kv_block, self.per_head_k_up_normal) - krope_nope = torch.cat((knope, k_pe_block.expand(-1,self.num_heads,-1,-1)), dim=-1) + krope_nope = torch.cat((knope, 
k_pe_block.expand(-1, self.num_heads, -1, -1)), dim=-1) qrope_nope = torch.cat((q_nope, q_pe), dim=-1) attn_weights_block = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale attn_weights_block = torch.where( causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block ) - current_max, current_denominator, output = update_running_softmax(current_max, attn_weights_block, current_denominator, output, compressed_kv_block, skip_kv, skip_future) + current_max, current_denominator, output = update_running_softmax( + current_max, + attn_weights_block, + current_denominator, + output, + compressed_kv_block, + skip_kv, + skip_future, + ) attn_output = torch.matmul(output, self.per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) @@ -544,7 +562,6 @@ def fused_forward_basic( if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - if enable_absorption: if absorb_online: print("online absorption") @@ -562,15 +579,17 @@ def fused_forward_basic( q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) query_states = torch.cat((q_nope, q_pe), dim=-1) - k_up_per_head = self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1,0,2) + k_up_per_head = ( + self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) + ) k_nope = torch.matmul(kva, k_up_per_head) - key_states = torch.cat((k_nope, k_pe.expand(-1,self.num_heads,-1,-1)), dim=-1) + key_states = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1, -1)), dim=-1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), 
attn_weights - ) + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) attn_output = torch.matmul(attn_weights, kva) @@ -580,8 +599,6 @@ def fused_forward_basic( attn_output = self.o_proj(attn_output) return attn_output, attn_weights, compressed_kvs, None - - def fused_forward( self, @@ -643,7 +660,7 @@ def fused_forward( mla_absorption, **kwargs, ) - + def forward( self, hidden_states: torch.Tensor, @@ -690,7 +707,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) @@ -702,7 +721,6 @@ def forward( return attn_output, attn_weights, past_key_value, value_states - class QEffDeepseekV3MoE(nn.Module): def __qeff_init__( self, @@ -768,17 +786,17 @@ def __qeff_init__( self, ): for exp in self.experts: - gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + down_proj = torch.nn.Linear(self.config.moe_intermediate_size, 
self.config.hidden_size, bias=False) - gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) - up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) - down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) + gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) + up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) + down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) - setattr(exp,"gate_proj", gate_proj) - setattr(exp,"up_proj", up_proj) - setattr(exp,"down_proj", down_proj) + setattr(exp, "gate_proj", gate_proj) + setattr(exp, "up_proj", up_proj) + setattr(exp, "down_proj", down_proj) def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) @@ -1040,15 +1058,16 @@ def forward( class QEffDeepseekV3ForCausalLM(nn.Module): """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations.""" + def get_submodules_for_export(self) -> Type[nn.Module]: - """ - Return the set of class used as the repeated layer across the model for subfunction extraction. - Notes: - This method should return the *class object* (not an instance). - Downstream code can use this to find/build subfunctions for repeated blocks. - """ - return {self.model.layers[0].__class__} - + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {self.model.layers[0].__class__} + def forward( self, input_ids: torch.LongTensor = None, @@ -1067,7 +1086,6 @@ def forward( num_logits_to_keep: int = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - before_keys = self.state_dict().keys() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c814fe1b53..7287720f6e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,13 +51,13 @@ KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, - PrefillOnlyExternalModuleMapperTransform, PrefillOnlyChunkedTransform, + PrefillOnlyExternalModuleMapperTransform, PrefillOnlyTransform, ReplicateKVHeadTransform, RevertPrefillKeepAttentionTransform, - RevertPrefillOnlyTransform, RevertPrefillOnlyExternalModuleMapperTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, TextClassificationTransform, @@ -3010,8 +3010,6 @@ def export( # kv_cache_shape = get_padding_shape_from_config( # self.model.config, fbs if self.continuous_batching else bs, seq_len # ) - ckv_shape = (1, seq_len, 512) - k_pe_shape = (1, 1, seq_len, 64) kv_cache_shape = (1, 64, seq_len, 192) kv_cache_shape_v = (1, 64, seq_len, 128) enable_chunking = kwargs.get("enable_chunking", False) @@ -3154,7 +3152,7 @@ def export( if enable_mla: for lay in self.model.model.layers: if lay is not None: - num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0]//576 + num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // 576 example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} diff --git 
a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 4a15273595..01d3fa8aae 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -786,16 +786,16 @@ class ReplicateKVHeadTransform: def _duplicate_weights_for_linear_layer( layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int ): - new_kv_heads = repeat #for mla + new_kv_heads = repeat # for mla layer.weight.data = torch.repeat_interleave( layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 ).view(new_kv_heads * dim, hidden_size) if layer.bias is not None: - layer.bias.data = torch.repeat_interleave( - layer.bias.data.view(orig_kv_heads, dim), repeat, 0 - ).view(new_kv_heads * dim) + layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( + new_kv_heads * dim + ) def _get_text_model(model): """ @@ -828,12 +828,11 @@ def apply(cls, model: nn.Module, **kwargs) -> nn.Module: if n_repeat is not None and n_repeat > 1: text_model = cls._get_text_model(model) - orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads - new_kv_heads = n_repeat*orig_kv_heads + orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads + new_kv_heads = n_repeat * orig_kv_heads text_model.config.orig_kv_heads = orig_kv_heads text_model.config.num_key_value_heads = new_kv_heads - num_attention_heads = text_model.config.num_attention_heads hidden_size = text_model.config.hidden_size logger.warning(f"Original KV heads: {orig_kv_heads}") @@ -1051,6 +1050,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, } + class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} _match_string_replace_method = { @@ -1061,6 +1061,7 @@ class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): }, } + class 
RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} _match_string_replace_method = { @@ -1070,10 +1071,11 @@ class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransfo "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, } - '''_match_string_replace_method = { + """_match_string_replace_method = { **{v: k for k, v in PrefillOnlyExternalModuleMapperTransform._match_string_replace_method.items()}, } - ''' + """ + class T5ModelTransform(ModuleMappingTransform): # supported architectures diff --git a/examples/export_kimik2.py b/examples/export_kimik2.py index 66a702bca2..b024aeb257 100644 --- a/examples/export_kimik2.py +++ b/examples/export_kimik2.py @@ -3,31 +3,34 @@ from QEfficient import QEFFAutoModelForCausalLM -#parameters to be configured +# parameters to be configured prompt = "Once upon a time," -num_kv_heads_repeat=1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number shoul be equal to TS in that case. -num_hidden_layers=2 -TS=4 -enable_mla=True -mla_absorption_config={"enable": True, "online": True} +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number shoul be equal to TS in that case. 
+num_hidden_layers = 2 +TS = 4 +enable_mla = True +mla_absorption_config = {"enable": True, "online": True} model = AutoModelForCausalLM.from_pretrained( - "moonshotai/Kimi-K2-Thinking", torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True + "moonshotai/Kimi-K2-Thinking", + torch_dtype=torch.float32, + num_hidden_layers=num_hidden_layers, + trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat = num_kv_heads_repeat) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat) qpc_path = qeff_model.compile( prefill_seq_len=1, - ctx_len=16*1024, + ctx_len=16 * 1024, enable_mla=enable_mla, mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=TS, - num_cores=16 + num_cores=16, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index 503efc12dc..ff62588f9c 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,6 +27,12 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from torch import nn # Example imports for three representative models @@ -56,12 +62,6 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from 
QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/examples/run_kimik2.py b/examples/run_kimik2.py index 735b3ace32..67388a899f 100644 --- a/examples/run_kimik2.py +++ b/examples/run_kimik2.py @@ -5,13 +5,15 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat=4 #TS=4 -num_hidden_layers=2 -enable_mla=True -mla_absorption_config={"enable": False, "online": False} - -#model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" -model_path ="/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" +num_kv_heads_repeat = 4 # TS=4 +num_hidden_layers = 2 +enable_mla = True +mla_absorption_config = {"enable": False, "online": False} + +# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +model_path = ( + "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" +) model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) @@ -60,7 +62,7 @@ inputs["compressed_kvs"] = compressed_kvs -#inputs["past_key_values"] = past_key_values +# inputs["past_key_values"] = past_key_values prefill_qeff_out = qeff_model.model(**inputs) @@ -78,7 +80,7 @@ "input_ids": next_token_id, "position_ids": position_ids, "compressed_kvs": qeff_out["past_key_values"], - #"past_key_values": qeff_out["past_key_values"], + # "past_key_values": qeff_out["past_key_values"], } qeff_out = qeff_model.model(**decode_inputs) @@ -102,7 +104,7 @@ 
mxint8_kv_cache=False, num_devices=num_kv_heads_repeat, num_cores=16, - #prefill_only=True, + # prefill_only=True, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 877a2def7c532792f745b9c1764052e72e518c64 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 10 Apr 2026 10:14:48 +0000 Subject: [PATCH 26/51] remove redundencies Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 16 +--- .../deepseek_v3/configuration_deepseek.py | 7 ++ .../models/deepseek_v3/modeling_deepseek.py | 22 +---- .../deepseek_v3/modeling_deepseek_orig.py | 22 +---- .../deepseek_v3/modeling_deepseek_qeff.py | 7 ++ .../transformers/models/modeling_auto.py | 80 +++++++++++-------- .../transformers/models/pytorch_transforms.py | 4 - examples/{ => kimi_k2}/export_kimik2.py | 9 ++- examples/{ => kimi_k2}/run_kimik2.py | 31 ++++--- examples/{ => kimi_k2}/run_orig_kimi_k2.py | 7 ++ .../causallm/example_pytorch_transforms.py | 12 +-- 11 files changed, 113 insertions(+), 104 deletions(-) rename examples/{ => kimi_k2}/export_kimik2.py (77%) rename examples/{ => kimi_k2}/run_kimik2.py (75%) rename examples/{ => kimi_k2}/run_orig_kimi_k2.py (74%) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 0c6767cf1f..b26f0a636e 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -271,17 +271,8 @@ def _export( ) elif param == "compressed_kvs": for i in range(len(example_inputs["compressed_kvs"])): - # input_names.extend([f"compressed_kvs.{i}",]) - input_names.extend( - [ - f"compressed_kv.{i}", - ] - ) - input_names.extend( - [ - f"k_pe.{i}", - ] - ) + input_names.extend([f"compressed_kv.{i}",]) + input_names.extend([f"k_pe.{i}",]) else: input_names.append(param) @@ -343,7 +334,6 @@ def get_onnx_path( retain_full_kv: Optional[bool] = False, enable_mla: Optional[bool] = False, mla_absorption_config: Optional[bool] = False, - mdp_ts_num_devices: Optional[int] = 1, ): kwargs = { "offload_pt_weights": 
offload_pt_weights, @@ -351,7 +341,6 @@ def get_onnx_path( "retain_full_kv": retain_full_kv, "enable_mla": enable_mla, "mla_absorption_config": mla_absorption_config, - "mdp_ts_num_devices": mdp_ts_num_devices, } if prefill_only: @@ -426,7 +415,6 @@ def _compile( retain_full_kv, enable_mla, mla_absorption_config, - mdp_ts_num_devices, ) ) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/models/deepseek_v3/configuration_deepseek.py b/QEfficient/transformers/models/deepseek_v3/configuration_deepseek.py index ece0a5e075..7f68c3d86e 100644 --- a/QEfficient/transformers/models/deepseek_v3/configuration_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/configuration_deepseek.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 261f3bf752..378d5577fc 100755 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1,23 +1,9 @@ -# coding=utf-8 -# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# ----------------------------------------------------------------------------- # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch DeepSeek model.""" +# ---------------------------------------------------------------------------- import math import warnings diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py index 8855ee88b5..f9566a491d 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py @@ -1,23 +1,9 @@ -# coding=utf-8 -# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# ----------------------------------------------------------------------------- # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch DeepSeek model.""" +# ---------------------------------------------------------------------------- import math import warnings diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index f3659a60f1..fa05a67d14 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + import math import os from typing import Dict, List, Optional, Tuple, Type, Union diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7287720f6e..b24faf3ba9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -75,6 +75,7 @@ ) from QEfficient.utils import ( constants, + get_padding_shape_from_config, ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -2805,7 +2806,7 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached - if self.model.config.model_type in {"kimi_k2"}: + if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) if replicate_kv_transformed: self.hash_params["config"] = model.config.to_diff_dict() @@ -3005,13 +3006,11 @@ def export( """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - # kv_cache_shape = get_padding_shape_from_config( - # self.model.config, fbs if self.continuous_batching else bs, seq_len - # ) - kv_cache_shape = (1, 64, seq_len, 192) - kv_cache_shape_v = (1, 64, seq_len, 128) + + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) enable_chunking = kwargs.get("enable_chunking", False) # TODO: HACK handle better @@ -3022,7 +3021,7 @@ def export( self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) - if self.model.config.model_type in {"kimi_k2"}: + if 
self.model.config.model_type in {"kimi_k2", "kimi_k25"}: if prefill_only: self.prefill(enable=True) self.hash_params["prefill_only"] = True @@ -3122,14 +3121,47 @@ def export( else pkv_dynamic_axes ) + if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: + if enable_mla: + for lay in self.model.model.layers: + if lay is not None: + num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // 576 + + example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} + dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} + output_names = [v for v in output_names if "past" not in v] + example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] + for i in range(self.num_layers): + ckv = torch.zeros((bs, num_heads, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) + k_pe = torch.zeros((bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + example_inputs["compressed_kvs"][i].append(ckv) + example_inputs["compressed_kvs"][i].append(k_pe) + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} + output_names.append(f"compressed_kv.{i}_RetainedState") + output_names.append(f"k_pe.{i}_RetainedState") + + else: + cache_shape_k = ( + 1, + self.model.config.num_attention_heads, + seq_len, + self.model.config.qk_nope_head_dim + self.model.config.qk_rope_head_dim, + ) + cache_shape_v = (1, self.model.config.num_attention_heads, seq_len, self.model.config.v_head_dim) + for i in range(self.num_layers): + example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_k, dtype=torch.float32)) + example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_v, dtype=torch.float32)) + dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes[i] + dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_key.{i}_RetainedState") + output_names.append(f"past_value.{i}_RetainedState") + else: 
for i in range(self.num_layers): - # for kv in ["key", "value"]: - example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape_v, dtype=torch.float32)) - dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes[i] - dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_key.{i}_RetainedState") - output_names.append(f"past_value.{i}_RetainedState") + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) @@ -3149,24 +3181,6 @@ def export( vocab_size=self.model.config.vocab_size, qaic_config=self.model.qaic_config, ) - if enable_mla: - for lay in self.model.model.layers: - if lay is not None: - num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // 576 - - example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} - dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} - output_names = [v for v in output_names if "past" not in v] - example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] - for i in range(self.num_layers): - ckv = torch.zeros((bs, num_heads, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros((bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) - example_inputs["compressed_kvs"][i].append(ckv) - example_inputs["compressed_kvs"][i].append(k_pe) - dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} - dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} - output_names.append(f"compressed_kv.{i}_RetainedState") - output_names.append(f"k_pe.{i}_RetainedState") return self._export( example_inputs, diff --git 
a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 01d3fa8aae..2eeeba1562 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1071,10 +1071,6 @@ class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransfo "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, } - """_match_string_replace_method = { - **{v: k for k, v in PrefillOnlyExternalModuleMapperTransform._match_string_replace_method.items()}, - } - """ class T5ModelTransform(ModuleMappingTransform): diff --git a/examples/export_kimik2.py b/examples/kimi_k2/export_kimik2.py similarity index 77% rename from examples/export_kimik2.py rename to examples/kimi_k2/export_kimik2.py index b024aeb257..6b9cc969ed 100644 --- a/examples/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -6,7 +13,7 @@ # parameters to be configured prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number shoul be equal to TS in that case. +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. 
num_hidden_layers = 2 TS = 4 enable_mla = True diff --git a/examples/run_kimik2.py b/examples/kimi_k2/run_kimik2.py similarity index 75% rename from examples/run_kimik2.py rename to examples/kimi_k2/run_kimik2.py index 67388a899f..379b99bda1 100644 --- a/examples/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -5,15 +12,14 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat = 4 # TS=4 +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. 
num_hidden_layers = 2 +TS = 4 enable_mla = True mla_absorption_config = {"enable": False, "online": False} -# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" -model_path = ( - "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" -) +model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +# model_path = "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) @@ -41,10 +47,15 @@ inputs.pop("token_type_ids", None) inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -pad_shape_k = (1, 64, CTX_LEN, 192) -pad_shape_v = (1, 64, CTX_LEN, 128) -pad_shape_ckv = (1, num_kv_heads_repeat, CTX_LEN, 512) -pad_shape_k_pe = (1, num_kv_heads_repeat, CTX_LEN, 64) +pad_shape_k = ( + 1, + model.config.num_attention_heads, + CTX_LEN, + model.config.qk_nope_head_dim + model.config.qk_rope_head_dim, +) +pad_shape_v = (1, model.config.num_attention_heads, CTX_LEN, model.config.v_head_dim) +pad_shape_ckv = (1, num_kv_heads_repeat, CTX_LEN, model.config.kv_lora_rank) +pad_shape_k_pe = (1, num_kv_heads_repeat, CTX_LEN, model.config.qk_rope_head_dim) past_key_values = [] compressed_kvs = [] @@ -102,7 +113,7 @@ mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, - num_devices=num_kv_heads_repeat, + num_devices=TS, num_cores=16, # prefill_only=True, ) diff --git a/examples/run_orig_kimi_k2.py b/examples/kimi_k2/run_orig_kimi_k2.py similarity index 74% rename from examples/run_orig_kimi_k2.py rename to examples/kimi_k2/run_orig_kimi_k2.py index 558329fbfb..03052dfc31 100644 --- a/examples/run_orig_kimi_k2.py +++ 
b/examples/kimi_k2/run_orig_kimi_k2.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + import torch from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9c..503efc12dc 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, From 38910d31b39dc6e923343849ab8da0e7da0d018e Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 10 Apr 2026 19:27:25 +0000 Subject: [PATCH 27/51] fix no absorption for head blocking Signed-off-by: Mamta Singh --- .../transformers/models/deepseek_v3/modeling_deepseek_qeff.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index fa05a67d14..4a266a630d 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -275,7 +275,8 @@ def __qeff_init__( self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) - self.per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) + per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) + self.per_head_k_up_normal = torch.nn.Parameter(per_head_k_up_normal.detach().clone()) fusedqk_list = [] for i in range(self.num_heads): From a0f0e61e5ec53dfa246dbd582731977af2a6e9ea Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 13 Apr 2026 05:24:27 +0000 Subject: [PATCH 28/51] Add head blocking in full pkv forward Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 12 +- .../deepseek_v3/modeling_deepseek_qeff.py | 128 +++++++++++++++++- .../transformers/models/modeling_auto.py | 4 +- .../transformers/models/pytorch_transforms.py | 2 + dbg.log | 0 examples/kimi_k2/run_kimik2.py | 29 ++-- .../causallm/example_pytorch_transforms.py | 12 +- 7 files changed, 160 insertions(+), 27 deletions(-) create mode 100644 dbg.log diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b26f0a636e..ff5f1c1b04 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -271,8 +271,16 @@ def _export( ) elif param == "compressed_kvs": for i in range(len(example_inputs["compressed_kvs"])): - input_names.extend([f"compressed_kv.{i}",]) - input_names.extend([f"k_pe.{i}",]) + input_names.extend( + [ + f"compressed_kv.{i}", + ] + ) + input_names.extend( 
+ [ + f"k_pe.{i}", + ] + ) else: input_names.append(param) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 4a266a630d..fe81cc76a3 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -372,14 +372,16 @@ def fused_forward_h_blocking( ) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) + attn_output = torch.matmul( + attn_output, self.per_head_v_up[:, h_start:h_end, :, :] + ) # TODO: merge this matmul with o_proj attn_output_list.append(attn_output) attn_weights_list.append(attn_weights) attn_output = torch.cat(attn_output_list, dim=1) attn_weights = torch.cat(attn_weights_list, dim=1) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) # 7168, 8192 return attn_output, attn_weights, compressed_kvs, None @@ -518,7 +520,7 @@ def fused_forward_blocked_kv( skip_future, ) - attn_output = torch.matmul(output, self.per_head_v_up) + attn_output = torch.matmul(output, self.per_head_v_up) # TODO: merge this matmul with o_proj attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -572,17 +574,14 @@ def fused_forward_basic( if enable_absorption: if absorb_online: - print("online absorption") out = torch.matmul(self.per_head_q_up, self.per_head_k_up) q_nope_compressed = torch.matmul(q_a_proj_out, out) else: - print("using fused qk") q_nope_compressed = torch.matmul(q_a_proj_out, self.fusedqk) query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) key_states = 
torch.cat((kva, k_pe), dim=-1) else: - print("no absorption") q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) query_states = torch.cat((q_nope, q_pe), dim=-1) @@ -669,7 +668,7 @@ def fused_forward( **kwargs, ) - def forward( + def forward_full_kv( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], @@ -728,6 +727,121 @@ def forward( return attn_output, attn_weights, past_key_value, value_states + def forward_full_kv_h_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view( + bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) # TODO : split this matmul #with k_up and v_up + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) 
+ + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + n_head_ckv = 4 # compressed_kv.shape[1] + + attn_output_list = [] + attn_weights_list = [] + for head_block_idx in range(self.num_heads // n_head_ckv): + h_start = head_block_idx * n_head_ckv + h_end = min(h_start + n_head_ckv, self.num_heads) + + attn_weights = ( + torch.matmul(query_states[:, h_start:h_end, :, :], key_states[:, h_start:h_end, :, :].transpose(2, 3)) + * self.softmax_scale + ) + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states[:, h_start:h_end, :, :]) + attn_output_list.append(attn_output) + attn_weights_list.append(attn_weights) + + attn_output = torch.cat(attn_output_list, dim=1) + attn_weights = torch.cat(attn_weights_list, dim=1) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value, value_states + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + 
output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if os.environ.get("KIMI_BLOCKING", "0") == "h": + return self.forward_full_kv_h_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + batch_index, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + else: + return self.forward_full_kv( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + batch_index, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + class QEffDeepseekV3MoE(nn.Module): def __qeff_init__( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b24faf3ba9..e376fc4a27 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3133,7 +3133,9 @@ def export( example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): ckv = torch.zeros((bs, num_heads, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros((bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32) + k_pe = torch.zeros( + (bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32 + ) example_inputs["compressed_kvs"][i].append(ckv) example_inputs["compressed_kvs"][i].append(k_pe) dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2eeeba1562..1377268a04 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1039,6 +1039,8 @@ class 
KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, "DeepseekV3Attention": { "forward": QEffDeepseekV3Attention.forward, + "forward_full_kv": QEffDeepseekV3Attention.forward_full_kv, + "forward_full_kv_h_blocking": QEffDeepseekV3Attention.forward_full_kv_h_blocking, "fused_forward": QEffDeepseekV3Attention.fused_forward, "fused_forward_blocked_kv": QEffDeepseekV3Attention.fused_forward_blocked_kv, "fused_forward_basic": QEffDeepseekV3Attention.fused_forward_basic, diff --git a/dbg.log b/dbg.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index 379b99bda1..e837d58cfa 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -18,8 +18,10 @@ enable_mla = True mla_absorption_config = {"enable": False, "online": False} -model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" -# model_path = "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" +# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +model_path = ( + "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" +) model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) @@ -35,9 +37,9 @@ num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -with torch.no_grad(): - out = model(**inputs) - predictions = torch.argmax(out.logits, dim=-1) +# with torch.no_grad(): +# out = model(**inputs) +# predictions = torch.argmax(out.logits, dim=-1) qeff_model = QEFFAutoModelForCausalLM(model, 
num_kv_heads_repeat=num_kv_heads_repeat) qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) @@ -72,13 +74,15 @@ compressed_kvs.append(x) -inputs["compressed_kvs"] = compressed_kvs -# inputs["past_key_values"] = past_key_values +if enable_mla: + inputs["compressed_kvs"] = compressed_kvs +else: + inputs["past_key_values"] = past_key_values prefill_qeff_out = qeff_model.model(**inputs) breakpoint() -assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 +# assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 position_ids = inputs["position_ids"] qeff_out = prefill_qeff_out @@ -90,9 +94,12 @@ decode_inputs = { "input_ids": next_token_id, "position_ids": position_ids, - "compressed_kvs": qeff_out["past_key_values"], - # "past_key_values": qeff_out["past_key_values"], } + if enable_mla: + decode_inputs["compressed_kvs"] = qeff_out["past_key_values"] + else: + decode_inputs["past_key_values"] = qeff_out["past_key_values"] + qeff_out = qeff_model.model(**decode_inputs) qeff_generated_ids.append(qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) @@ -104,7 +111,7 @@ prefill_seq_len = 1 -ctx_len = 1024 +ctx_len = 16 * 1024 qpc_path = qeff_model.compile( prefill_seq_len=prefill_seq_len, diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index 503efc12dc..ff62588f9c 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,6 +27,12 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from torch import nn # Example imports for three representative models @@ -56,12 +62,6 @@ from 
QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, From 5b56df49f074c4030f13c5b01bf88bccf3da163d Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 13 Apr 2026 08:02:50 +0000 Subject: [PATCH 29/51] fix tests Signed-off-by: Mamta Singh --- .../transformers/models/modeling_auto.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index e376fc4a27..ed721962d3 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3121,6 +3121,12 @@ def export( else pkv_dynamic_axes ) + for i in range(self.num_layers): + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_RetainedState") + if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: if enable_mla: for lay in self.model.model.layers: @@ -3151,19 +3157,10 @@ def export( self.model.config.qk_nope_head_dim + self.model.config.qk_rope_head_dim, ) cache_shape_v = (1, self.model.config.num_attention_heads, seq_len, self.model.config.v_head_dim) + example_inputs["past_key_values"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_k, dtype=torch.float32)) 
example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_v, dtype=torch.float32)) - dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes[i] - dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_key.{i}_RetainedState") - output_names.append(f"past_value.{i}_RetainedState") - else: - for i in range(self.num_layers): - for kv in ["key", "value"]: - example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) @@ -3421,7 +3418,7 @@ def compile( """ if mla_absorption_config and not enable_mla: - logger.warning("enable_mla_fusion will be ignored as enable_mla is set to False") + logger.warning("mla_absorption_config will be ignored as enable_mla is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( "`kv_cache_batch_size` or `full_batch_size` is being passed" From 4d4dd0c8245c13668461db38b6bfa4457f4ab25a Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 13 Apr 2026 18:09:23 +0000 Subject: [PATCH 30/51] fix CI errors Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 2 ++ .../deepseek_v3/modeling_deepseek_qeff.py | 21 +++++++++++++++++++ .../models/grok_1/modeling_grok1.py | 2 +- .../transformers/models/pytorch_transforms.py | 3 ++- 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ff5f1c1b04..dd761fdf3b 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -236,6 +236,8 @@ def _export( """ # TODO: Hack for retain_full_kv, handle this outside export_kwargs.pop("retain_full_kv", None) + export_kwargs.pop("enable_mla", None) + export_kwargs.pop("mla_absorption_config", None) onnx_path = export_dir / 
f"{self.model_name}.onnx" # Return early if ONNX already exists diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index fe81cc76a3..8c7b415ddf 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -15,6 +15,7 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -54,6 +55,26 @@ def yarn_linear_ramp_mask(min, max, dim): return ramp_func +class QEffDeepseekV3CustomRMSNormAIC(nn.Module): + """ + RMSNorm module that works by replacing the current module with compiler known custom-op. + """ + + def forward(self, hidden_states): + """ + Forward pass of the RMSNorm module. + + Args: + hidden_states (torch.Tensor): Input tensor to be normalized. + + Returns: + torch.Tensor: Normalized tensor. + """ + return CustomRMSNormFunc.apply( + hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps + ) + + class DeepseekV3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 0f88fe1b92..51bdaa4ea4 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -39,7 +39,7 @@ def forward(self, hidden_states): torch.Tensor: Normalized tensor. 
""" return CustomRMSNormFunc.apply( - hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps + hidden_states, self.scale, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 1377268a04..3ce533f561 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -252,6 +252,7 @@ ) from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import ( QEffDeepseekV3Attention, + QEffDeepseekV3CustomRMSNormAIC, QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3Model, @@ -1048,7 +1049,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, "DeepseekV3RMSNorm": { - "forward": QEFFGrok1CustomRMSNormAIC.forward, + "forward": QEffDeepseekV3CustomRMSNormAIC.forward, }, } From 37dd1bb7bde4d1ace5c62bd11e50d4d13fea5686 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 15 Apr 2026 15:52:58 +0000 Subject: [PATCH 31/51] added orig forward from snippet Signed-off-by: Mamta Singh --- .../deepseek_v3/modeling_deepseek_qeff.py | 152 +++++++++++++++++- .../transformers/models/pytorch_transforms.py | 1 + examples/kimi_k2/run_kimik2.py | 2 +- 3 files changed, 149 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 8c7b415ddf..c6648d703c 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -404,7 +404,7 @@ def fused_forward_h_blocking( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) # 
def fused_forward_orig(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    compressed_kvs: Optional[torch.Tensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    mla_absorption: Optional[Dict[str, bool]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Non-blocked MLA attention over the compressed KV cache.

    Projects ``hidden_states`` to a compressed KV latent plus a rotary key part,
    optionally appends both to ``compressed_kvs`` (via ``update_ckv`` /
    ``update_k_pe``), expands them to ``self.num_heads`` heads and runs one
    full softmax attention followed by ``self.o_proj``.

    ``mla_absorption`` selects how the non-rotary query/key parts are formed:
      * ``{"enable": False}`` or ``None`` - queries go through ``self.q_up`` and
        keys through ``self.k_up``;
      * ``{"enable": True, "online": False}`` - queries are multiplied by the
        precomputed ``self.fusedqk`` and scored directly against the latent;
      * ``{"enable": True, "online": True}`` - same, but the fused matrix is
        rebuilt from ``self.per_head_q_up @ self.per_head_k_up`` on each call.

    ``position_embeddings``, ``past_key_value``, ``output_attentions``,
    ``use_cache``, ``cache_position`` and ``**kwargs`` are accepted for
    signature parity with the sibling ``fused_forward_*`` methods and are not
    read here.

    Returns:
        ``(attn_output, attn_weights, compressed_kvs)`` where ``compressed_kvs``
        is the same cache object that was passed in.
    """
    bsz, q_len, _ = hidden_states.size()

    # ---- KV compression ----
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2)
    compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

    # ---- Q projections ----
    q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states))

    q_pe = torch.bmm(q_a_proj_out, self.q_rope)
    q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2)

    compressed_kv = self.kv_a_layernorm(compressed_kv)

    cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
    if compressed_kvs is not None:
        compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs)

    kva = compressed_kv

    # ---- MLA absorption flags ----
    # Both flags are always bound so the branch below never depends on which
    # arm of this conditional ran.
    enable_absorption = False
    absorb_online = False
    if mla_absorption is not None:
        enable_absorption = mla_absorption.get("enable", False)
        absorb_online = mla_absorption.get("online", False)

    n_head_ckv = kva.shape[1]
    p = self.num_heads // n_head_ckv  # query heads served by each latent head
    seq_kv = kva.shape[2]

    # ---- Rotary ----
    cos, sin = self.rotary_emb(q_pe, seq_len=32 * 1024)  # Doesn't need q_pe as head_dim is initialized
    q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

    if compressed_kvs is not None:
        k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs)

    # Repeat every latent / rotary head p times so they line up with num_heads.
    kva_expanded = kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank)
    k_pe_expanded = (
        k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim)
    )

    v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2)
    value_states = torch.matmul(kva_expanded, v_up_per_head)

    if enable_absorption:
        if absorb_online:
            fused_qk = torch.matmul(self.per_head_q_up, self.per_head_k_up)
        else:
            fused_qk = self.fusedqk
        q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), fused_qk)
        query_states = torch.cat((q_nope_compressed, q_pe), dim=-1)
        key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1)
    else:
        q_nope = torch.bmm(q_a_proj_out, self.q_up)
        q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        query_states = torch.cat((q_nope, q_pe), dim=-1)

        k_up_per_head = (
            self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2)
        )
        k_nope = torch.matmul(kva_expanded, k_up_per_head)
        key_states = torch.cat((k_nope, k_pe_expanded), dim=-1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

    if attention_mask is not None:
        # Positions where attention_mask is True are replaced by the masked value.
        attn_weights = torch.where(
            attention_mask,
            torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype),
            attn_weights,
        )
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim)

    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, compressed_kvs
index 30f44c710b..cdb3f596a6 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1044,6 +1044,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "fused_forward": QEffDeepseekV3Attention.fused_forward, "fused_forward_blocked_kv": QEffDeepseekV3Attention.fused_forward_blocked_kv, "fused_forward_basic": QEffDeepseekV3Attention.fused_forward_basic, + "fused_forward_orig": QEffDeepseekV3Attention.fused_forward_orig, "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index e837d58cfa..cb5f54580e 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -12,7 +12,7 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. +num_kv_heads_repeat = 4 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. 
num_hidden_layers = 2 TS = 4 enable_mla = True From 210d953f1ce24258df7ca31191f248b86533d3f1 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Thu, 16 Apr 2026 20:18:50 +0000 Subject: [PATCH 32/51] add support for subfunctions Signed-off-by: Mamta Singh --- QEfficient/base/onnx_transforms.py | 17 ++- .../deepseek_v3/modeling_deepseek_qeff.py | 124 +++++++++--------- .../transformers/models/pytorch_transforms.py | 1 + QEfficient/utils/export_utils.py | 3 +- examples/kimi_k2/run_kimik2.py | 4 +- 5 files changed, 74 insertions(+), 75 deletions(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 2ba53829a4..c27e3cc704 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -150,13 +150,16 @@ def apply(cls, model: ModelProto) -> bool: if "_InternalRetainedState" in out_name: renamed = True orig = node.output[i] - new = ( - f"past_key.{layer_idx}_RetainedState" - if "key" in out_name - else f"past_value.{layer_idx}_RetainedState" - if "value" in out_name - else orig - ) + if "key" in out_name: + new = f"past_key.{layer_idx}_RetainedState" + elif "value" in out_name: + new = f"past_value.{layer_idx}_RetainedState" + elif "compressed_kv" in out_name: + new = f"compressed_kv.{layer_idx}_RetainedState" + elif "k_pe" in out_name: + new = f"k_pe.{layer_idx}_RetainedState" + else: + new = orig node.output[i] = new if orig in model_out_map: graph.output[model_out_map[orig]].name = new diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index c6648d703c..b4f78b46a2 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -326,7 +326,9 @@ def fused_forward_h_blocking( compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + 
self.qk_rope_head_dim).transpose(1, 2) - compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.matmul(q_a_proj_out, self.q_rope) @@ -334,12 +336,10 @@ def fused_forward_h_blocking( q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - compressed_kv = self.kv_a_layernorm(compressed_kv) + kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: - compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) - - kva = compressed_kv + kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -347,7 +347,7 @@ def fused_forward_h_blocking( else: enable_absorption = False - n_head_ckv = compressed_kv.shape[1] + n_head_ckv = kva.shape[1] cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -429,7 +429,9 @@ def fused_forward_blocked_kv( compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.matmul(q_a_proj_out, self.q_rope) @@ -437,7 +439,7 @@ def fused_forward_blocked_kv( q_nope = torch.matmul(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - 
compressed_kv = self.kv_a_layernorm(compressed_kv) + kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if mla_absorption is not None: @@ -448,13 +450,13 @@ def fused_forward_blocked_kv( ## Write Only if compressed_kvs is not None: - compressed_kv = compressed_kvs.write_only_ckv(compressed_kv, self.layer_idx, cache_kwargs) + compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) if enable_absorption: if absorb_online: @@ -566,7 +568,9 @@ def fused_forward_basic( compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.bmm(q_a_proj_out, self.q_rope) @@ -574,12 +578,10 @@ def fused_forward_basic( q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - compressed_kv = self.kv_a_layernorm(compressed_kv) + kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: - compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs) - - kva = compressed_kv + kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -628,7 +630,6 @@ def 
fused_forward_basic( return attn_output, attn_weights, compressed_kvs - def fused_forward_orig( self, hidden_states: torch.Tensor, @@ -644,37 +645,29 @@ def fused_forward_orig( mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() print("using orig forward") # ---- KV compression ---- compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view( - bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim - ).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - compressed_kv, k_pe = compressed_kv.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] # ---- Q projections ---- q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.bmm(q_a_proj_out, self.q_rope) - q_pe = q_pe.view( - bsz, q_len, self.num_heads, self.qk_rope_head_dim - ).transpose(1, 2) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - compressed_kv = self.kv_a_layernorm(compressed_kv) + kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if compressed_kvs is not None: - compressed_kv = compressed_kvs.update_ckv( - compressed_kv, self.layer_idx, cache_kwargs - ) + kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - kva = compressed_kv + # kva = compressed_kv # ---- MLA absorption flags ---- if mla_absorption is not None: @@ -692,19 +685,17 @@ def fused_forward_orig( q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe( - k_pe, self.layer_idx, cache_kwargs - ) + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - kva_expanded = 
kva.unsqueeze(2).expand(-1, -1, p, -1, -1) \ - .reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) + kva_expanded = ( + kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) + ) - k_pe_expanded = k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1) \ - .reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim) + k_pe_expanded = ( + k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim) + ) - v_up_per_head = self.v_up.squeeze(0) \ - .view(self.kv_lora_rank, self.num_heads, self.v_head_dim) \ - .permute(1, 0, 2) + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) value_states = torch.matmul(kva_expanded, v_up_per_head) @@ -721,14 +712,12 @@ def fused_forward_orig( key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) else: q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view( - bsz, q_len, self.num_heads, self.qk_nope_head_dim - ).transpose(1, 2) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) query_states = torch.cat((q_nope, q_pe), dim=-1) - k_up_per_head = self.k_up.squeeze(0) \ - .view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim) \ - .permute(1, 0, 2) + k_up_per_head = ( + self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) + ) k_nope = torch.matmul(kva_expanded, k_up_per_head) key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) @@ -742,19 +731,12 @@ def fused_forward_orig( ) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) ## Do v_proj here - attn_output = torch.matmul( - attn_weights, value_states - ) - attn_output = ( - attn_output.transpose(1, 2) - .contiguous() - .reshape(bsz, q_len, self.num_heads * self.v_head_dim) - ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 
q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, compressed_kvs - def fused_forward( self, hidden_states: torch.Tensor, @@ -850,18 +832,24 @@ def forward_full_kv( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, self.kv_lora_rank :] + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + self.kv_b_proj(self.kv_a_layernorm(kva)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -909,20 +897,26 @@ def forward_full_kv_h_blocking( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - compressed_kv, k_pe = 
torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, self.kv_lora_rank :] + kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + self.kv_b_proj(self.kv_a_layernorm(kva)) .view( bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) # TODO : split this matmul #with k_up and v_up .transpose(1, 2) ) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -935,7 +929,7 @@ def forward_full_kv_h_blocking( cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - n_head_ckv = 4 # compressed_kv.shape[1] + n_head_ckv = 4 # kva.shape[1] attn_output_list = [] attn_weights_list = [] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index cdb3f596a6..e6c70522af 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1027,6 +1027,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, "DeepseekV3ForCausalLM": { "forward": QEffDeepseekV3ForCausalLM.forward, + "get_submodules_for_export": QEffDeepseekV3ForCausalLM.get_submodules_for_export, }, "DeepseekV3Model": {"forward": QEffDeepseekV3Model.forward, "__qeff_init__": QEffDeepseekV3Model.__qeff_init__}, "DeepseekV3DecoderLayer": { diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 5c4ee8054c..901484e724 100644 --- a/QEfficient/utils/export_utils.py +++ 
b/QEfficient/utils/export_utils.py @@ -166,7 +166,8 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs): if "output_names" in kwargs: kwargs["output_names"] = [ re.sub("_RetainedState", "_InternalRetainedState", name) - if name.endswith("_RetainedState") and ("key" in name or "value" in name) + if name.endswith("_RetainedState") + and ("key" in name or "value" in name or "compressed_kv" in name or "k_pe" in name) else name for name in kwargs["output_names"] ] diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index cb5f54580e..277211ef72 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -12,7 +12,7 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat = 4 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. +num_kv_heads_repeat = 4 # When using KIMI_BLOCKING="kv" or "basic", make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. 
# --- added to QEfficient/blocking/attention_blocking.py ---
def generic_blocked_mla_attention_interface(
    module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    layer_idx: int,
    compressed_kvs: torch.Tensor,
    enable_absorption: bool,
    blocking_config: AttentionBlockingConfig,
    comp_ctx_lengths: Optional[torch.LongTensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_seen_tokens: Optional[int] = None,
    non_blocked_forward: Callable = None,
    score_mod: Optional[Callable] = None,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    **kwargs,
):
    """Dispatch MLA attention to the KV-blocked implementation.

    Unlike a mode-switching interface, this always calls
    ``blocked_kv_mla_attention_forward``; ``blocking_config`` only supplies the
    block counts. ``position_ids`` and ``batch_index`` are packed into the
    ``cache_kwargs`` dict that the forward hands to the compressed-KV cache.

    ``comp_ctx_lengths``, ``past_seen_tokens``, ``non_blocked_forward``,
    ``sliding_window`` and ``**kwargs`` are accepted but not forwarded.

    Returns:
        ``(attn_output, attn_weights)`` exactly as produced by the forward.
    """
    cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
    return blocked_kv_mla_attention_forward(
        module=module,
        query=query,
        key=key,
        value=value,
        attention_mask=attention_mask,
        scaling=scaling,
        cache_kwargs=cache_kwargs,
        layer_idx=layer_idx,
        compressed_kvs=compressed_kvs,
        enable_absorption=enable_absorption,
        num_kv_blocks=blocking_config.num_kv_blocks,
        # The forward does not name the next four parameters, so they end up
        # in its **kwargs, which its body never reads.
        num_q_blocks=blocking_config.num_q_blocks,
        head_block_size=blocking_config.head_block_size,
        num_batch_blocks=blocking_config.num_batch_blocks,
        score_mod=score_mod,
        position_bias=position_bias,
        sinks=sinks,
    )


# --- added to QEfficient/blocking/blocked_attention_forwards.py ---
def blocked_kv_mla_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    compressed_kvs: Optional[torch.Tensor],
    enable_absorption: bool,
    *,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    skip_kv: bool = False,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """MLA attention computed one KV block at a time with a running softmax.

    The context held in ``compressed_kvs`` is split into ``num_kv_blocks``
    slices. For each slice the compressed latent and rotary key are read back
    from the cache, scores against ``query`` are masked with a freshly built
    causal mask, and ``update_running_softmax`` folds the slice into running
    max / denominator / output accumulators. The accumulated output (width
    ``module.config.kv_lora_rank``) is finally multiplied by ``value``.

    Args:
        query: ``(batch, heads, q_len, dim)`` query tensor.
        key: only read when ``enable_absorption`` is False, where the latent
            block is multiplied by it before scoring.
        value: matrix applied to the accumulated latent-space output.
        attention_mask: not read; the mask is rebuilt per block from
            ``position_ids``.
        cache_kwargs: must contain ``"position_ids"``; passed through to the
            cache read calls.
        compressed_kvs: cache exposing ``layers[0].ckv``,
            ``read_only_blocked_ckv`` and ``read_only_blocked_k_pe``.
        enable_absorption: if True, score directly against
            ``cat(latent, k_pe)``; otherwise against
            ``cat(latent @ key, k_pe expanded over heads)``.
        skip_kv: currently has no effect - see the NOTE in the body.

    ``use_causal_mask``, ``sliding_window``, ``position_bias``, ``sinks`` and
    ``**kwargs`` are accepted but not read.

    Returns:
        ``(attn_output, None)``; ``attn_output`` is transposed to
        ``(batch, q_len, heads, ...)`` and made contiguous.

    Raises:
        ValueError: if ``compressed_kvs`` or ``cache_kwargs["position_ids"]``
            is None (both are dereferenced unconditionally below).
    """
    if compressed_kvs is None:
        raise ValueError("blocked_kv_mla_attention_forward requires compressed_kvs, got None")
    position_ids = cache_kwargs.get("position_ids")
    if position_ids is None:
        raise ValueError("blocked_kv_mla_attention_forward requires cache_kwargs['position_ids'], got None")

    # Initialize result tensor
    batch_size, num_heads, seq_len, _ = query.shape
    output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device)

    # Initialize Running Maximum and Denominator
    current_max = torch.full(
        (batch_size, num_heads, seq_len),
        float(MIN_MASKED_ATTENTION_VALUE),
        device=query.device,
    )
    current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device)

    # NOTE(review): the caller-supplied skip_kv is unconditionally overridden,
    # so future-block skipping is always on. Kept as-is to preserve behaviour;
    # confirm whether the keyword should be honoured instead.
    skip_kv = True

    # Context length is taken from layer 0's latent cache, not from layer_idx.
    ctx_len = compressed_kvs.layers[0].ckv.shape[2]
    kv_block_size = -(-ctx_len // num_kv_blocks)  # ceil division

    current_position = position_ids.max(dim=-1).values

    # Loop-invariant fill value for masked score positions.
    masked_fill_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)

    for j in range(num_kv_blocks):
        start_index = j * kv_block_size
        if j == num_kv_blocks - 1:
            kv_len_block = ctx_len - start_index  # last block takes the remainder
        else:
            kv_len_block = kv_block_size
        end_index = start_index + kv_len_block

        skip_future = None
        if skip_kv:
            # True when this block starts beyond every row's current position.
            skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()
            # Eager mode Only: data-dependent early exit cannot be traced/exported.
            if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                if skip_future.item():
                    break

        compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs)
        k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, layer_idx, cache_kwargs)

        causal_mask_block = _create_causal_mask(
            position_ids=position_ids,
            target_length=end_index,
            start_index=start_index,
        )

        # Only the key construction differs between the two modes.
        if enable_absorption:
            # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size]
            krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1)
        else:
            knope = torch.matmul(compressed_kv_block, key)
            krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1)

        attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling
        attn_weights_block = torch.where(causal_mask_block, masked_fill_value, attn_weights_block)

        # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512]
        current_max, current_denominator, output = update_running_softmax(
            current_max,
            attn_weights_block,
            current_denominator,
            output,
            compressed_kv_block,
            skip_kv,
            skip_future,
        )

    attn_output = torch.matmul(output, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_weights = None  # per-block weights are never materialised

    return attn_output, attn_weights
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""Compute the maximum kv_block_size under an fp16 memory budget.

Constraints (bytes) per matmul, numbered as in ``matmul1_bytes`` /
``matmul2_bytes``:
1) [1, num_heads, q_len, kv] x [1, 1, kv, 512] -> [1, num_heads, q_len, 512]
2) [1, num_heads, q_len, 576] x [1, 1, 576, kv] -> [1, num_heads, q_len, kv]

For each matmul, sum(input_a + input_b + output) must be < budget.
The returned kv_block_size satisfies both constraints.
"""

from bisect import bisect_right
from typing import List

from QEfficient.utils.constants import VTCM_SIZE_THRESHOLD

FP16_BYTES = 2
DEFAULT_NUM_HEADS = 64
VTCM_SIZE_THRESHOLD = int(VTCM_SIZE_THRESHOLD)

# Inner dimensions of the two matmuls above.
# NOTE(review): presumably kv_lora_rank (512) and kv_lora_rank +
# qk_rope_head_dim (576) of the target model - confirm against its config.
_VALUE_DIM = 512
_SCORE_DIM = 576


def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int:
    """Bytes for [1,num_heads,q,kv] x [1,1,kv,512] -> [1,num_heads,q,512] in fp16."""
    elems_a = num_heads * q_len * kv_block_size
    elems_b = kv_block_size * _VALUE_DIM
    elems_out = num_heads * q_len * _VALUE_DIM
    return FP16_BYTES * (elems_a + elems_b + elems_out)


def matmul2_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int:
    """Bytes for [1,num_heads,q,576] x [1,1,576,kv] -> [1,num_heads,q,kv] in fp16."""
    elems_a = num_heads * q_len * _SCORE_DIM
    elems_b = _SCORE_DIM * kv_block_size
    elems_out = num_heads * q_len * kv_block_size
    return FP16_BYTES * (elems_a + elems_b + elems_out)


def max_kv_block_size(
    q_len: int,
    budget_bytes: int = VTCM_SIZE_THRESHOLD,
    num_heads: int = DEFAULT_NUM_HEADS,
) -> int:
    """Return the largest integer kv_block_size that satisfies both matmul budgets.

    Returns 0 if no positive kv_block_size can satisfy the constraints.

    Raises:
        ValueError: if ``q_len`` is negative, or ``budget_bytes`` / ``num_heads``
            is not positive.
    """
    if q_len < 0:
        raise ValueError("q_len must be non-negative")
    if budget_bytes <= 0:
        raise ValueError("budget_bytes must be positive")
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")

    # Enforce strict inequality in bytes:
    # FP16_BYTES * elems < budget_bytes  =>  elems <= floor((budget_bytes - 1) / FP16_BYTES)
    max_elems = (budget_bytes - 1) // FP16_BYTES
    rows = num_heads * q_len

    # Matmul1: rows*kv + kv*512 + rows*512 <= max_elems
    #   => kv <= (max_elems - rows*512) / (rows + 512)
    rem1 = max_elems - rows * _VALUE_DIM
    k1 = rem1 // (rows + _VALUE_DIM) if rem1 >= 0 else -1

    # Matmul2: rows*576 + 576*kv + rows*kv <= max_elems
    #   => kv <= (max_elems - rows*576) / (rows + 576)
    rem2 = max_elems - rows * _SCORE_DIM
    k2 = rem2 // (rows + _SCORE_DIM) if rem2 >= 0 else -1

    return max(0, min(k1, k2))


def block_candidates_generator(max_length: int) -> List[int]:
    """Return the allowed block sizes ``<= max_length`` in ascending order.

    The sequence starts 1, 2, 3, 4 and doubles its step every time it reaches a
    multiple of four steps: 1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, ...
    Returns an empty list when ``max_length < 1``.
    """
    block_list = []
    i = 1
    step = 1
    while i <= max_length:
        block_list.append(i)
        if i % (4 * step) == 0:
            step *= 2
        i += step
    return block_list


def get_num_kv_blocks_for_mla(q_len, num_heads, ctx_len):
    """Return how many KV blocks ``ctx_len`` must be split into to fit the budget.

    The largest candidate block size (see ``block_candidates_generator``) that
    does not exceed the budget-derived maximum is chosen, and the block count
    is the ceiling of ``ctx_len / block_size``. Ceiling (rather than floor)
    division guarantees that a consumer splitting ``ctx_len`` into that many
    near-equal blocks (``ceil(ctx_len / n)`` per block) never ends up with a
    block larger than the chosen size.

    Raises:
        ValueError: if ``ctx_len`` is not positive, or if not even a block of
            size 1 fits in ``VTCM_SIZE_THRESHOLD`` for this ``q_len`` /
            ``num_heads``.
    """
    if ctx_len <= 0:
        raise ValueError("ctx_len must be positive")

    budget_bytes = VTCM_SIZE_THRESHOLD
    kv = max_kv_block_size(q_len, budget_bytes, num_heads)
    b1 = matmul1_bytes(q_len, kv, num_heads)
    b2 = matmul2_bytes(q_len, kv, num_heads)

    # Internal invariants of max_kv_block_size; they can only fail when kv == 0.
    assert b1 < budget_bytes, "matmul1 is not under the budget"
    assert b2 < budget_bytes, "matmul2 is not under the budget"
    if kv < 1:
        raise ValueError(
            f"no positive kv_block_size fits in {budget_bytes} bytes for q_len={q_len}, num_heads={num_heads}"
        )

    # The whole context fits in one block.
    if kv >= ctx_len:
        return 1

    # Largest candidate <= kv. candidates[0] == 1 <= kv, so the index is valid;
    # this also covers kv landing exactly on a candidate, which the previous
    # strict "a < kv < b" scan left unassigned.
    candidates = block_candidates_generator(ctx_len)
    kv_block_size = candidates[bisect_right(candidates, kv) - 1]
    return -(-ctx_len // kv_block_size)  # ceil division
a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index b4f78b46a2..4ba1a8b3a2 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -15,6 +15,10 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from QEfficient.blocking.attention_blocking import ( + AttentionBlockingConfig, + generic_blocked_mla_attention_interface, +) from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -216,54 +220,6 @@ def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -def update_running_softmax( - current_max: torch.Tensor, - attn_weights_block: torch.Tensor, - current_denominator: torch.Tensor, - output: torch.Tensor, - v_block: torch.Tensor, - skip_kv: bool = False, - skip_future: Optional[torch.Tensor] = None, -): - # Update Running row maximum - prev_max = current_max - current_max_updated = torch.max(prev_max, attn_weights_block.max(dim=3).values) - delta_max = prev_max - current_max_updated - - current_exp = torch.exp(attn_weights_block - current_max_updated.unsqueeze(-1)) - - # update running denominator - prev_denominator = current_denominator - curr_exp_sum = torch.einsum("bhqk->bhq", current_exp) - current_denominator_updated = prev_denominator * torch.exp(delta_max) + curr_exp_sum - - prob = current_exp / current_denominator_updated.unsqueeze(-1) - - prev_output = output - # if updating running softmax with attention sinks, we don't have v_block - if v_block is not None: - output_updated = ((prev_denominator / current_denominator_updated).unsqueeze(-1)) * prev_output 
* torch.exp( - delta_max.unsqueeze(-1) - ) + torch.matmul(prob, v_block) - else: - output_updated = ( - ((prev_denominator / current_denominator_updated).unsqueeze(-1)) - * prev_output - * torch.exp(delta_max.unsqueeze(-1)) - ) - - if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): - current_max = torch.where(skip_future, prev_max, current_max_updated) - current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) - output = torch.where(skip_future.unsqueeze(-1), prev_output, output_updated) - else: - # Eager mode - current_max = current_max_updated - current_denominator = current_denominator_updated - output = output_updated - return current_max, current_denominator, output - - class QEffDeepseekV3Attention(nn.Module): """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" @@ -422,10 +378,6 @@ def fused_forward_blocked_kv( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - num_kv_blocks = 4 - - ctx_len = compressed_kvs.layers[0].ckv.shape[2] - kv_block_size = -(-ctx_len // num_kv_blocks) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) @@ -436,18 +388,10 @@ def fused_forward_blocked_kv( q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) q_pe = torch.matmul(q_a_proj_out, self.q_rope) q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.matmul(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) - else: - 
enable_absorption = False - ## Write Only if compressed_kvs is not None: compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) @@ -458,6 +402,13 @@ def fused_forward_blocked_kv( if compressed_kvs is not None: compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) + else: + enable_absorption = False + if enable_absorption: if absorb_online: qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) @@ -465,86 +416,29 @@ def fused_forward_blocked_kv( else: dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe), dim=-1) - batch_size, num_heads, seq_len, _ = qkupTrope_nope.shape + query = qkupTrope_nope else: - batch_size, num_heads, seq_len, _ = q_nope.shape - - current_position = position_ids.max(dim=-1).values - skip_kv = True - output = torch.zeros(batch_size, self.num_heads, seq_len, self.kv_lora_rank, device=hidden_states.device) - - current_max = torch.full( - (batch_size, num_heads, seq_len), - float(MIN_MASKED_ATTENTION_VALUE), - device=hidden_states.device, + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + qnope_rope = torch.cat((q_nope, q_pe), dim=-1) + query = qnope_rope + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + query=query, + key=self.per_head_k_up_normal, + value=self.per_head_v_up, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + compressed_kvs=compressed_kvs, + enable_absorption=enable_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, ) - current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=hidden_states.device) - - for 
j in range(num_kv_blocks): - start_index = j * kv_block_size - if j == num_kv_blocks - 1: - kv_len_block = ctx_len - start_index - else: - kv_len_block = kv_block_size - end_index = start_index + kv_len_block - - skip_future = None - if skip_kv: - skip_future = (torch.tensor(start_index, device=hidden_states.device) > current_position).all() - # Eager mode Only - if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): - if skip_future.item(): - break - - compressed_kv_block = compressed_kvs.read_only_blocked_ckv( - start_index, end_index, self.layer_idx, cache_kwargs - ) - k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, self.layer_idx, cache_kwargs) - - causal_mask_block = _create_causal_mask( - position_ids=position_ids, - target_length=end_index, - start_index=start_index, - ) - - if enable_absorption: - krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) - attn_weights_block = ( - torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - ) # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] - attn_weights_block = torch.where( - causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block - ) - current_max, current_denominator, output = update_running_softmax( - current_max, - attn_weights_block, - current_denominator, - output, - compressed_kv_block, - skip_kv, - skip_future, - ) - # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] - else: - knope = torch.matmul(compressed_kv_block, self.per_head_k_up_normal) - krope_nope = torch.cat((knope, k_pe_block.expand(-1, self.num_heads, -1, -1)), dim=-1) - qrope_nope = torch.cat((q_nope, q_pe), dim=-1) - attn_weights_block = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - attn_weights_block = torch.where( - causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block - ) - 
current_max, current_denominator, output = update_running_softmax( - current_max, - attn_weights_block, - current_denominator, - output, - compressed_kv_block, - skip_kv, - skip_future, - ) - attn_output = torch.matmul(output, self.per_head_v_up) # TODO: merge this matmul with o_proj - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) return attn_output, None, compressed_kvs @@ -708,8 +602,8 @@ def fused_forward_orig( q_a_proj_out.unsqueeze(1), self.fusedqk, ) - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) else: q_nope = torch.bmm(q_a_proj_out, self.q_up) q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index e6c70522af..234ba94033 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1153,7 +1153,7 @@ def apply(cls, model: nn.Module, attn_blocking_config) -> Tuple[nn.Module, bool] if type(module) in cls._skip_classes: warnings.warn(f"Blocking is not yet supported for {type(module)}.") continue - if type(module) in supported_attention_classes: + if type(module) in supported_attention_classes or model.config.model_type == "kimi_k2": module.attn_blocking_config = attn_blocking_config transformed = True elif module.__class__.__name__.endswith("Attention") and type(module) not in supported_attention_classes: diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index 277211ef72..fcc8b7c420 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py 
@@ -5,6 +5,8 @@ # # ---------------------------------------------------------------------------- +import os + import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -12,11 +14,18 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat = 4 # When using KIMI_BLOCKING="kv" or "basic", make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or "basic", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" and this should be equal to TS in that case. num_hidden_layers = 2 TS = 4 enable_mla = True mla_absorption_config = {"enable": False, "online": False} +qaic_config = None + +if os.environ.get("KIMI_BLOCKING", "0") == "kv": + qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} + +if os.environ.get("KIMI_BLOCKING", "0") == "h": + num_kv_heads_repeat = 4 # model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" model_path = ( @@ -81,7 +90,7 @@ prefill_qeff_out = qeff_model.model(**inputs) -breakpoint() +# breakpoint() # assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 position_ids = inputs["position_ids"] @@ -123,6 +132,7 @@ num_devices=TS, num_cores=16, use_onnx_subfunctions=True, + qaic_config=qaic_config, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 2126672ecceed5ca95d358ab0cbe2ec55768c5de Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 22 Apr 2026 20:50:08 +0000 Subject: [PATCH 34/51] use generalised infra for h_blocking with ckv and full kv Signed-off-by: Mamta Singh --- QEfficient/blocking/attention_blocking.py | 45 ++- .../blocking/blocked_attention_forwards.py | 102 ++++++- QEfficient/blocking/blocking_configurator.py | 2 +- QEfficient/blocking/get_num_blocks.py | 2 - 
.../deepseek_v3/modeling_deepseek_qeff.py | 258 ++++-------------- .../transformers/models/modeling_auto.py | 6 +- .../transformers/models/pytorch_transforms.py | 7 +- examples/kimi_k2/export_kimik2.py | 22 +- examples/kimi_k2/run_kimik2.py | 11 +- 9 files changed, 216 insertions(+), 239 deletions(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index 5a67f5e10c..2ab5c03bec 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional import torch from transformers.cache_utils import Cache @@ -17,6 +17,7 @@ from QEfficient.blocking.blocked_attention_forwards import ( blocked_bhqkv_attention_forward, blocked_h_attention_forward, + blocked_h_mla_attention_forward, blocked_hqkv_attention_forward, blocked_kv_attention_forward, blocked_kv_mla_attention_forward, @@ -58,6 +59,11 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: BlockingMode.BHQKV: blocked_bhqkv_attention_forward, } +_STRATEGIES_MLA: Dict[BlockingMode, Callable] = { + BlockingMode.KV: blocked_kv_mla_attention_forward, + BlockingMode.H: blocked_h_mla_attention_forward, +} + # helper function needed both in generic blocked approach and in other modeling files for non-blocked approach def past_key_value_update( @@ -165,15 +171,23 @@ def generic_blocked_attention_interface( def generic_blocked_mla_attention_interface( module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - layer_idx: int, - compressed_kvs: torch.Tensor, - enable_absorption: bool, + mla_absorption: Dict[str, Any], blocking_config: AttentionBlockingConfig, + query: Optional[torch.Tensor] = None, + q_a_proj_out: Optional[torch.Tensor] = None, + fusedqk: Optional[torch.Tensor] = None, + q_nope: 
Optional[torch.Tensor] = None, + q_pe: Optional[torch.Tensor] = None, + kva: Optional[torch.Tensor] = None, + k_pe: Optional[torch.Tensor] = None, + per_head_q_up: Optional[torch.Tensor] = None, + per_head_k_up: Optional[torch.Tensor] = None, + per_head_v_up: Optional[torch.Tensor] = None, + per_head_k_up_normal: Optional[torch.Tensor] = None, + layer_idx: Optional[int] = None, + compressed_kvs: Optional[torch.Tensor] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -186,17 +200,26 @@ def generic_blocked_mla_attention_interface( **kwargs, ): cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - attn_output, attn_weights = blocked_kv_mla_attention_forward( + mla_blocking_strategy = _STRATEGIES_MLA.get(blocking_config.mode) + attn_output, attn_weights = mla_blocking_strategy( module=module, query=query, - key=key, - value=value, + q_a_proj_out=q_a_proj_out, + fusedqk=fusedqk, + q_nope=q_nope, + q_pe=q_pe, + kva=kva, + k_pe=k_pe, + per_head_q_up=per_head_q_up, + per_head_k_up=per_head_k_up, + per_head_v_up=per_head_v_up, + per_head_k_up_normal=per_head_k_up_normal, attention_mask=attention_mask, scaling=scaling, cache_kwargs=cache_kwargs, layer_idx=layer_idx, compressed_kvs=compressed_kvs, - enable_absorption=enable_absorption, + mla_absorption=mla_absorption, num_kv_blocks=blocking_config.num_kv_blocks, num_q_blocks=blocking_config.num_q_blocks, head_block_size=blocking_config.head_block_size, diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 5a75899a05..5e4ed12f50 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -822,15 +822,15 @@ def blocked_q_attention_forward( def blocked_kv_mla_attention_forward( module: nn.Module, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + 
per_head_k_up_normal: torch.Tensor, + per_head_v_up: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, num_kv_blocks: int, cache_kwargs: Dict[str, Any], layer_idx: int, compressed_kvs: Optional[torch.Tensor], - enable_absorption: bool, + mla_absorption: Dict[str, Any], *, use_causal_mask: bool = False, sliding_window: Optional[int] = None, @@ -843,6 +843,12 @@ def blocked_kv_mla_attention_forward( batch_size, num_heads, seq_len, _ = query.shape output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device) + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = query.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) + # Initialize Running Maximum and Denominator current_max = torch.full( (batch_size, num_heads, seq_len), @@ -883,13 +889,13 @@ def blocked_kv_mla_attention_forward( start_index=start_index, ) + enable_absorption = mla_absorption.get("enable", False) + if enable_absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] - attn_weights_block = torch.where( - causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block - ) + attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) current_max, current_denominator, output = update_running_softmax( current_max, attn_weights_block, @@ -900,12 +906,10 @@ def blocked_kv_mla_attention_forward( skip_future, ) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: - knope = torch.matmul(compressed_kv_block, key) + knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1) attn_weights_block = 
torch.matmul(query, krope_nope.transpose(2, 3)) * scaling - attn_weights_block = torch.where( - causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block - ) + attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) current_max, current_denominator, output = update_running_softmax( current_max, attn_weights_block, @@ -916,8 +920,84 @@ def blocked_kv_mla_attention_forward( skip_future, ) - attn_output = torch.matmul(output, value) + attn_output = torch.matmul(output, per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous() attn_weights = None return attn_output, attn_weights + + +def blocked_h_mla_attention_forward( + module: nn.Module, + q_a_proj_out: torch.Tensor, + fusedqk: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kva: torch.Tensor, + k_pe: torch.Tensor, + per_head_q_up: torch.Tensor, + per_head_k_up: torch.Tensor, + per_head_v_up: torch.Tensor, + per_head_k_up_normal: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + mla_absorption: Dict[str, Any], + head_block_size: int, + *, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + H-blocked attention that slices along head dimension to create blocks and processes each block. 
+ """ + batch_size, num_heads, q_len, _ = q_pe.shape + if head_block_size <= 0: + head_block_size = num_heads + num_head_blocks = math.ceil(num_heads / head_block_size) + + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = q_pe.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=q_pe.device) + + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) + else: + enable_absorption = False + + h_output_blocks = [] + h_attn_blocks = [] + # Process each head block independently + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, num_heads) + + if enable_absorption: + if absorb_online: + qup_kupT = torch.matmul(per_head_q_up[:, h_start:h_end, :, :], per_head_k_up[:, h_start:h_end, :, :]) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, fusedqk[:, h_start:h_end, :, :]) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) + krope_nope = torch.cat((kva, k_pe), dim=-1) + attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * scaling + else: + knope = torch.matmul(kva, per_head_k_up_normal[:, h_start:h_end, :, :]) + krope_nope = torch.cat((knope, k_pe), dim=-1) + qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) + attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, masked_tensor, attn_weights) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, per_head_v_up[:, h_start:h_end, :, :]) + h_output_blocks.append(attn_output) + h_attn_blocks.append(attn_weights) 
+ + attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() + attn_weights = torch.cat(h_attn_blocks, dim=1) + return attn_output, attn_weights diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index c650320327..eee3257917 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -201,7 +201,7 @@ def build_transformer_blocking_config( int(data_bytes), blocking_mode=blocking_mode, ) - if model_config.model_type == "kimi_k2": + if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): attention_cfg["num_kv_blocks"] = get_num_kv_blocks_for_mla(seq_len, num_heads, ctx_len) resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv") diff --git a/QEfficient/blocking/get_num_blocks.py b/QEfficient/blocking/get_num_blocks.py index f97ed8cfeb..519acd59cb 100644 --- a/QEfficient/blocking/get_num_blocks.py +++ b/QEfficient/blocking/get_num_blocks.py @@ -21,7 +21,6 @@ FP16_BYTES = 2 DEFAULT_NUM_HEADS = 64 -VTCM_SIZE_THRESHOLD = int(VTCM_SIZE_THRESHOLD) def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: @@ -101,7 +100,6 @@ def get_num_kv_blocks_for_mla(q_len, num_heads, ctx_len): kv = max_kv_block_size(q_len, budget_bytes, num_heads) b1 = matmul1_bytes(q_len, kv, num_heads) b2 = matmul2_bytes(q_len, kv, num_heads) - assert b1 < budget_bytes, "matmul1 is not under the budget" assert b2 < budget_bytes, "matmul2 is not under the budget" kv_block_size_list = block_candidates_generator(ctx_len) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py index 4ba1a8b3a2..c4688edfee 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py @@ -17,6 +17,7 @@ from 
QEfficient.blocking.attention_blocking import ( AttentionBlockingConfig, + generic_blocked_attention_interface, generic_blocked_mla_attention_interface, ) from QEfficient.customop.rms_norm import CustomRMSNormFunc @@ -255,12 +256,9 @@ def __qeff_init__( per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) self.per_head_k_up_normal = torch.nn.Parameter(per_head_k_up_normal.detach().clone()) - fusedqk_list = [] - for i in range(self.num_heads): - fusedqk_list.append(torch.matmul(per_head_q_up[i, :, :], per_head_k_up[i, :, :])) - fusedqk = torch.cat(fusedqk_list, dim=0) - fusedqk = fusedqk.reshape(1, self.num_heads, -1, self.kv_lora_rank) - + fusedqk = torch.bmm(per_head_q_up, per_head_k_up).reshape( + -1, self.num_heads, self.q_lora_rank, self.kv_lora_rank + ) self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) def fused_forward_h_blocking( @@ -297,72 +295,40 @@ def fused_forward_h_blocking( if compressed_kvs is not None: kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) - else: - enable_absorption = False - - n_head_ckv = kva.shape[1] - cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - attn_output_list = [] - attn_weights_list = [] - for head_block_idx in range(self.num_heads // n_head_ckv): - h_start = head_block_idx * n_head_ckv - h_end = min(h_start + n_head_ckv, self.num_heads) - - if enable_absorption: - if absorb_online: - qup_kupT = torch.matmul( - self.per_head_q_up[:, h_start:h_end, :, :], self.per_head_k_up[:, h_start:h_end, :, :] - ) - dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) - else: - dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk[:, h_start:h_end, :, :]) - qkupTrope_nope = 
torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) - krope_nope = torch.cat((kva, k_pe), dim=-1) - attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qkupTrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - else: - knope = torch.matmul(kva, self.per_head_k_up_normal[:, h_start:h_end, :, :]) - krope_nope = torch.cat((knope, k_pe), dim=-1) - qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) - attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(qrope_nope.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul( - attn_output, self.per_head_v_up[:, h_start:h_end, :, :] - ) # TODO: merge this matmul with o_proj - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) - - attn_output = torch.cat(attn_output_list, dim=1) - attn_weights = torch.cat(attn_weights_list, dim=1) - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) # 7168, 8192 + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + q_a_proj_out=q_a_proj_out, + 
fusedqk=self.fusedqk, + q_nope=q_nope, + q_pe=q_pe, + kva=kva, + k_pe=k_pe, + per_head_q_up=self.per_head_q_up, + per_head_k_up=self.per_head_k_up, + per_head_v_up=self.per_head_v_up, + per_head_k_up_normal=self.per_head_k_up_normal, + attention_mask=attention_mask, + scaling=self.softmax_scale, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) return attn_output, attn_weights, compressed_kvs - def fused_forward_blocked_kv( + def fused_forward_kv_blocking( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], @@ -402,7 +368,6 @@ def fused_forward_blocked_kv( if compressed_kvs is not None: compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) absorb_online = mla_absorption.get("online", False) @@ -423,16 +388,18 @@ def fused_forward_blocked_kv( qnope_rope = torch.cat((q_nope, q_pe), dim=-1) query = qnope_rope + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + attn_output, attn_weights = generic_blocked_mla_attention_interface( module=self, query=query, - key=self.per_head_k_up_normal, - value=self.per_head_v_up, + per_head_k_up_normal=self.per_head_k_up_normal, + per_head_v_up=self.per_head_v_up, attention_mask=attention_mask, scaling=self.softmax_scale, layer_idx=self.layer_idx, compressed_kvs=compressed_kvs, - enable_absorption=enable_absorption, + mla_absorption=mla_absorption, blocking_config=blocking_config, position_ids=position_ids, **kwargs, @@ -443,87 +410,6 @@ def fused_forward_blocked_kv( return attn_output, None, compressed_kvs - def fused_forward_basic( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], 
- attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - mla_absorption: Optional[Dict[str, bool]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - - kva = compressed_kv[:, :, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - - q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.bmm(q_a_proj_out, self.q_rope) - q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - - kva = self.kv_a_layernorm(kva) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if compressed_kvs is not None: - kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - - if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) - else: - enable_absorption = False - - cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - - if enable_absorption: - if absorb_online: - out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out, out) - else: - 
q_nope_compressed = torch.matmul(q_a_proj_out, self.fusedqk) - - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) - key_states = torch.cat((kva, k_pe), dim=-1) - else: - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - query_states = torch.cat((q_nope, q_pe), dim=-1) - - k_up_per_head = ( - self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) - ) - k_nope = torch.matmul(kva, k_up_per_head) - key_states = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1, -1)), dim=-1) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - attn_output = torch.matmul(attn_weights, kva) - attn_output = torch.matmul(attn_output, self.per_head_v_up) - - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, compressed_kvs - def fused_forward_orig( self, hidden_states: torch.Tensor, @@ -561,8 +447,6 @@ def fused_forward_orig( if compressed_kvs is not None: kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - # kva = compressed_kv - # ---- MLA absorption flags ---- if mla_absorption is not None: enable_absorption = mla_absorption.get("enable", False) @@ -570,8 +454,8 @@ def fused_forward_orig( else: enable_absorption = False - n_head_ckv = kva.shape[1] - p = self.num_heads // n_head_ckv + head_block_size = kva.shape[1] + p = self.num_heads // head_block_size seq_kv = kva.shape[2] # ---- Rotary ---- @@ -590,7 +474,6 @@ def fused_forward_orig( ) v_up_per_head = 
self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) - value_states = torch.matmul(kva_expanded, v_up_per_head) if enable_absorption: @@ -662,22 +545,7 @@ def fused_forward( **kwargs, ) elif os.environ.get("KIMI_BLOCKING", "0") == "kv": - return self.fused_forward_blocked_kv( - hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - compressed_kvs, - batch_index, - output_attentions, - use_cache, - cache_position, - mla_absorption, - **kwargs, - ) - elif os.environ.get("KIMI_BLOCKING", "0") == "basic": - return self.fused_forward_basic( + return self.fused_forward_kv_blocking( hidden_states, position_embeddings, attention_mask, @@ -770,7 +638,7 @@ def forward_full_kv( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value, value_states + return attn_output, attn_weights, past_key_value def forward_full_kv_h_blocking( self, @@ -798,8 +666,8 @@ def forward_full_kv_h_blocking( compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - kva = compressed_kv[:, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, self.kv_lora_rank :] + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] kv = ( self.kv_b_proj(self.kv_a_layernorm(kva)) @@ -819,39 +687,25 @@ def forward_full_kv_h_blocking( k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) key_states = torch.cat((k_nope, k_pe_new), -1) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - n_head_ckv = 4 # kva.shape[1] - - attn_output_list = [] - attn_weights_list = 
[] - for head_block_idx in range(self.num_heads // n_head_ckv): - h_start = head_block_idx * n_head_ckv - h_end = min(h_start + n_head_ckv, self.num_heads) - - attn_weights = ( - torch.matmul(query_states[:, h_start:h_end, :, :], key_states[:, h_start:h_end, :, :].transpose(2, 3)) - * self.softmax_scale - ) - - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states[:, h_start:h_end, :, :]) - attn_output_list.append(attn_output) - attn_weights_list.append(attn_weights) - - attn_output = torch.cat(attn_output_list, dim=1) - attn_weights = torch.cat(attn_weights_list, dim=1) + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + past_key_value=past_key_value, + blocking_config=blocking_config, + batch_index=batch_index, + position_ids=position_ids, + ) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value, value_states + return attn_output, attn_weights, past_key_value def forward( self, @@ -1075,7 +929,7 @@ def forward( **kwargs, ) else: - hidden_states, self_attn_weights, present_key_value, vs = self.self_attn( + hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=orig_hidden_states, attention_mask=attention_mask, position_ids=position_ids, diff --git a/QEfficient/transformers/models/modeling_auto.py 
b/QEfficient/transformers/models/modeling_auto.py index 2918e7728b..16f038665f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2871,7 +2871,7 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached - if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: + if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) if replicate_kv_transformed: self.hash_params["config"] = model.config.to_diff_dict() @@ -3093,7 +3093,7 @@ def export( self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) - if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: + if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): if prefill_only: self.prefill(enable=True) self.hash_params["prefill_only"] = True @@ -3203,7 +3203,7 @@ def export( dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") - if self.model.config.model_type in {"kimi_k2", "kimi_k25"}: + if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): if enable_mla: for lay in self.model.model.layers: if lay is not None: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 234ba94033..2cdb107956 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1043,8 +1043,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "forward_full_kv": QEffDeepseekV3Attention.forward_full_kv, "forward_full_kv_h_blocking": QEffDeepseekV3Attention.forward_full_kv_h_blocking, "fused_forward": 
QEffDeepseekV3Attention.fused_forward, - "fused_forward_blocked_kv": QEffDeepseekV3Attention.fused_forward_blocked_kv, - "fused_forward_basic": QEffDeepseekV3Attention.fused_forward_basic, + "fused_forward_kv_blocking": QEffDeepseekV3Attention.fused_forward_kv_blocking, "fused_forward_orig": QEffDeepseekV3Attention.fused_forward_orig, "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, @@ -1153,7 +1152,9 @@ def apply(cls, model: nn.Module, attn_blocking_config) -> Tuple[nn.Module, bool] if type(module) in cls._skip_classes: warnings.warn(f"Blocking is not yet supported for {type(module)}.") continue - if type(module) in supported_attention_classes or model.config.model_type == "kimi_k2": + if type(module) in supported_attention_classes or "DeepseekV3ForCausalLM" in ( + getattr(model.config, "architectures", None) or [] + ): module.attn_blocking_config = attn_blocking_config transformed = True elif module.__class__.__name__.endswith("Attention") and type(module) not in supported_attention_classes: diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 6b9cc969ed..554a9a6c55 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -5,6 +5,8 @@ # # ---------------------------------------------------------------------------- +import os + import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -13,11 +15,23 @@ # parameters to be configured prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or None, make sure this is set to 1. Use only for KIMI_BLOCKING="h" and this number should be equal to TS in that case. +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" when using MLA and this should be equal to TS in that case. 
num_hidden_layers = 2 TS = 4 enable_mla = True -mla_absorption_config = {"enable": True, "online": True} +mla_absorption_config = {"enable": False, "online": False} +qaic_config = None + +if os.environ.get("KIMI_BLOCKING", "0") == "kv": + qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} + num_kv_heads_repeat = 1 + +if os.environ.get("KIMI_BLOCKING", "0") == "h": + if enable_mla: + num_kv_heads_repeat = TS + else: + num_kv_heads_repeat = 1 # head replication is not needed if MLA is not enabled + qaic_config = {"enable_blocking": True, "blocking_mode": "h", "head_block_size": TS} model = AutoModelForCausalLM.from_pretrained( "moonshotai/Kimi-K2-Thinking", @@ -27,7 +41,7 @@ ) tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) -qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat, qaic_config=qaic_config) qpc_path = qeff_model.compile( prefill_seq_len=1, @@ -38,6 +52,8 @@ mxint8_kv_cache=False, num_devices=TS, num_cores=16, + # use_onnx_subfunctions=True, + qaic_config=qaic_config, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index fcc8b7c420..6724802e83 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -14,7 +14,7 @@ from QEfficient import QEFFAutoModelForCausalLM prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv" or "basic", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" and this should be equal to TS in that case. +num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" when using MLA and this should be equal to TS in that case. 
num_hidden_layers = 2 TS = 4 enable_mla = True @@ -23,9 +23,14 @@ if os.environ.get("KIMI_BLOCKING", "0") == "kv": qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} + num_kv_heads_repeat = 1 if os.environ.get("KIMI_BLOCKING", "0") == "h": - num_kv_heads_repeat = 4 + if enable_mla: + num_kv_heads_repeat = TS + else: + num_kv_heads_repeat = 1 # head replication is not needed if MLA is not enabled + qaic_config = {"enable_blocking": True, "blocking_mode": "h", "head_block_size": TS} # model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" model_path = ( @@ -50,7 +55,7 @@ # out = model(**inputs) # predictions = torch.argmax(out.logits, dim=-1) -qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat) +qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat, qaic_config=qaic_config) qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) From c9fd13e2de5476b765e9440d44f9d8f74a558047 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Thu, 23 Apr 2026 21:55:05 +0530 Subject: [PATCH 35/51] remove redundant files Signed-off-by: Mamta Singh --- QEfficient/blocking/blocking_configurator.py | 86 +- QEfficient/blocking/get_num_blocks.py | 109 - .../models/deepseek_v3/modeling_deepseek.py | 1980 +++++++---------- .../deepseek_v3/modeling_deepseek_orig.py | 1653 -------------- .../deepseek_v3/modeling_deepseek_qeff.py | 1162 ---------- .../transformers/models/pytorch_transforms.py | 2 +- examples/kimi_k2/export_kimik2.py | 59 - examples/kimi_k2/run_kimik2.py | 2 + examples/kimi_k2/run_orig_kimi_k2.py | 36 - 9 files changed, 833 insertions(+), 4256 deletions(-) delete mode 100644 QEfficient/blocking/get_num_blocks.py mode change 100755 => 100644 QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py 
# Size in bytes of a single fp16 element.
FP16_BYTES = 2
# Head count used when a caller does not supply one.
DEFAULT_NUM_HEADS = 64


def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int:
    """Bytes for [1,num_heads,q,kv] x [1,1,kv,512] -> [1,num_heads,q,512] in fp16.

    The result is the sum of both operands and the output, each counted at
    ``FP16_BYTES`` per element.
    NOTE(review): 512 presumably is the compressed-KV (latent) width - confirm.
    """
    elems_a = num_heads * q_len * kv_block_size
    elems_b = kv_block_size * 512
    elems_out = num_heads * q_len * 512
    return FP16_BYTES * (elems_a + elems_b + elems_out)


def matmul2_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int:
    """Bytes for [1,num_heads,q,576] x [1,1,576,kv] -> [1,num_heads,q,kv] in fp16.

    The result is the sum of both operands and the output, each counted at
    ``FP16_BYTES`` per element.
    NOTE(review): 576 presumably is the latent width plus the rope width - confirm.
    """
    elems_a = num_heads * q_len * 576
    elems_b = 576 * kv_block_size
    elems_out = num_heads * q_len * kv_block_size
    return FP16_BYTES * (elems_a + elems_b + elems_out)


def max_kv_block_size(
    q_len: int,
    budget_bytes: int = VTCM_SIZE_THRESHOLD,
    num_heads: int = DEFAULT_NUM_HEADS,
) -> int:
    """Return the largest integer kv_block_size that satisfies both matmul budgets.

    Both ``matmul1_bytes`` and ``matmul2_bytes`` must be strictly below
    ``budget_bytes`` for the returned size.

    Returns 0 if no positive kv_block_size can satisfy the constraints.

    Raises:
        ValueError: if ``q_len`` is negative, or ``budget_bytes`` or
            ``num_heads`` is not positive.
    """
    if q_len < 0:
        raise ValueError("q_len must be non-negative")
    if budget_bytes <= 0:
        raise ValueError("budget_bytes must be positive")
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")

    # Enforce strict inequality in bytes:
    # FP16_BYTES * elems < budget_bytes  =>  elems <= floor((budget_bytes - 1) / FP16_BYTES)
    max_elems = (budget_bytes - 1) // FP16_BYTES

    # Matmul1 elements:
    #   A_elems = num_heads*q_len*kv, B_elems = kv*512, C_elems = num_heads*q_len*512
    # Enforce A_elems + B_elems + C_elems <= max_elems and solve for kv.
    c1_elems = num_heads * q_len * 512
    rem1 = max_elems - c1_elems
    den1 = num_heads * q_len + 512  # kv coefficient from A_elems + B_elems
    k1 = rem1 // den1 if rem1 >= 0 else -1

    # Matmul2 elements:
    #   A_elems = num_heads*q_len*576, B_elems = 576*kv, C_elems = num_heads*q_len*kv
    # Enforce A_elems + B_elems + C_elems <= max_elems and solve for kv.
    a2_elems = num_heads * q_len * 576
    rem2 = max_elems - a2_elems
    den2 = num_heads * q_len + 576  # kv coefficient from B_elems + C_elems
    k2 = rem2 // den2 if rem2 >= 0 else -1

    kv = min(k1, k2)
    return max(0, kv)


def get_num_kv_blocks_for_mla(q_len: int, num_heads: int, ctx_len: int) -> int:
    """Return how many KV blocks ``ctx_len`` is split into under the fp16 memory budget.

    Constraints (bytes) per matmul:
    1) [1, num_heads, q_len, 576] x [1, 1, 576, kv] -> [1, num_heads, q_len, kv]
    2) [1, num_heads, q_len, kv] x [1, 1, kv, 512] -> [1, num_heads, q_len, 512]

    For each matmul, sum(input_a + input_b + output) must be < budget.
    The largest value from ``block_candidates_generator(ctx_len)`` that does
    not exceed the budget-limited kv size is used as the block size, and the
    result is ``ctx_len // kv_block_size``.

    Raises:
        ValueError: if no positive kv block size fits the budget, or if
            ``ctx_len`` yields no block-size candidate.
    """
    budget_bytes = VTCM_SIZE_THRESHOLD
    kv = max_kv_block_size(q_len, budget_bytes, num_heads)
    if kv < 1:
        raise ValueError(
            f"No positive kv_block_size fits the VTCM budget for q_len={q_len}, num_heads={num_heads}"
        )

    # Invariants guaranteed by max_kv_block_size; kept as cheap sanity checks.
    b1 = matmul1_bytes(q_len, kv, num_heads)
    b2 = matmul2_bytes(q_len, kv, num_heads)
    assert b1 < budget_bytes, "matmul1 is not under the budget"
    assert b2 < budget_bytes, "matmul2 is not under the budget"

    # Pick the largest candidate that still fits. A strict "between two
    # neighbours" search would leave the block size unset whenever kv equals a
    # candidate or exceeds the largest one.
    kv_block_size = max(
        (candidate for candidate in block_candidates_generator(ctx_len) if candidate <= kv),
        default=None,
    )
    if kv_block_size is None:
        raise ValueError(f"No kv block size candidate available for ctx_len={ctx_len}")
    return ctx_len // kv_block_size
-""" - -from typing import List - -from QEfficient.utils.constants import VTCM_SIZE_THRESHOLD - -FP16_BYTES = 2 -DEFAULT_NUM_HEADS = 64 - - -def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: - """Bytes for [1,num_heads,q,kv] x [1,1,kv,512] -> [1,num_heads,q,512] in fp16.""" - elems_a = num_heads * q_len * kv_block_size - elems_b = kv_block_size * 512 - elems_out = num_heads * q_len * 512 - return FP16_BYTES * (elems_a + elems_b + elems_out) - - -def matmul2_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: - """Bytes for [1,num_heads,q,576] x [1,1,576,kv] -> [1,num_heads,q,kv] in fp16.""" - elems_a = num_heads * q_len * 576 - elems_b = 576 * kv_block_size - elems_out = num_heads * q_len * kv_block_size - return FP16_BYTES * (elems_a + elems_b + elems_out) - - -def max_kv_block_size( - q_len: int, - budget_bytes: int = VTCM_SIZE_THRESHOLD, - num_heads: int = DEFAULT_NUM_HEADS, -) -> int: - """Return the largest integer kv_block_size that satisfies both matmul budgets. - - Returns 0 if no positive kv_block_size can satisfy the constraints. 
- """ - if q_len < 0: - raise ValueError("q_len must be non-negative") - if budget_bytes <= 0: - raise ValueError("budget_bytes must be positive") - if num_heads <= 0: - raise ValueError("num_heads must be positive") - - # Enforce strict inequality in bytes: - # FP16_BYTES * elems < budget_bytes => elems <= floor((budget_bytes - 1)/FP16_BYTES) - max_elems = (budget_bytes - 1) // FP16_BYTES - - # Matmul1 elements: - # A_elems = num_heads*q_len*kv - # B_elems = kv*512 - # C_elems = num_heads*q_len*512 - # Enforce A_elems + B_elems + C_elems <= max_elems - c1_elems = num_heads * q_len * 512 - rem1 = max_elems - c1_elems - den1 = num_heads * q_len + 512 # kv coefficient from A_elems + B_elems - k1 = rem1 // den1 if rem1 >= 0 else -1 - - # Matmul2 elements: - # A_elems = num_heads*q_len*576 - # B_elems = 576*kv - # C_elems = num_heads*q_len*kv - # Enforce A_elems + B_elems + C_elems <= max_elems - a2_elems = num_heads * q_len * 576 - rem2 = max_elems - a2_elems - den2 = num_heads * q_len + 576 # kv coefficient from B_elems + C_elems - k2 = rem2 // den2 if rem2 >= 0 else -1 - - kv = min(k1, k2) - return max(0, kv) - - -def block_candidates_generator(max_length: int) -> List[int]: - block_list = [] - i = 1 - step = 1 - while i <= max_length: - block_list.append(i) - if i % (4 * step) == 0: - step *= 2 - i += step - return block_list - - -def get_num_kv_blocks_for_mla(q_len, num_heads, ctx_len): - budget_bytes = VTCM_SIZE_THRESHOLD - kv = max_kv_block_size(q_len, budget_bytes, num_heads) - b1 = matmul1_bytes(q_len, kv, num_heads) - b2 = matmul2_bytes(q_len, kv, num_heads) - assert b1 < budget_bytes, "matmul1 is not under the budget" - assert b2 < budget_bytes, "matmul2 is not under the budget" - kv_block_size_list = block_candidates_generator(ctx_len) - for i in range(len(kv_block_size_list) - 1): - if kv_block_size_list[i] < kv < kv_block_size_list[i + 1]: - kv_block_size = kv_block_size_list[i] - return ctx_len // kv_block_size diff --git 
a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py old mode 100755 new mode 100644 index 378d5577fc..c4688edfee --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -6,92 +6,78 @@ # ---------------------------------------------------------------------------- import math -import warnings -from typing import List, Optional, Tuple, Union +import os +from typing import Dict, List, Optional, Tuple, Type, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ( - ALL_LAYERNORM_LAYERS, - is_torch_greater_or_equal_than_1_13, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from QEfficient.blocking.attention_blocking import ( + AttentionBlockingConfig, + generic_blocked_attention_interface, + generic_blocked_mla_attention_interface, ) -from transformers.utils.import_utils import is_torch_fx_available +from QEfficient.customop.rms_norm import CustomRMSNormFunc +from QEfficient.transformers.cache_utils import QEffDynamicCache, 
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Inverse dim formula: the (fractional) dimension for a given number of rotations."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator


def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the (low, high) dimension bounds for the two rotation counts, clamped to [0, dim - 1]."""
    lower = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case they fall outside the valid index range.
    lower = max(lower, 0)
    upper = min(upper, dim - 1)
    return lower, upper


def yarn_get_mscale(scale=1, mscale=1):
    """Return 1.0 for scale <= 1, otherwise 0.1 * mscale * ln(scale) + 1.0."""
    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0


def yarn_linear_ramp_mask(min, max, dim):
    """Build a length-``dim`` float32 ramp that rises linearly from 0 at ``min`` to 1 at ``max``."""
    if min == max:
        # Prevent singularity (division by zero) when both bounds coincide.
        max += 0.001

    positions = torch.arange(dim, dtype=torch.float32)
    ramp = (positions - min) / (max - min)
    return torch.clamp(ramp, 0, 1)


class QEffDeepseekV3CustomRMSNormAIC(nn.Module):
    """
    RMSNorm module that works by replacing the current module with compiler known custom-op.
    """

    def forward(self, hidden_states):
        """
        Normalize ``hidden_states`` through ``CustomRMSNormFunc``.

        Args:
            hidden_states (torch.Tensor): Input tensor to be normalized.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        # The epsilon is read from `variance_epsilon` when present, else from `eps`.
        if hasattr(self, "variance_epsilon"):
            epsilon = self.variance_epsilon
        else:
            epsilon = self.eps
        return CustomRMSNormFunc.apply(hidden_states, self.weight, epsilon)
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -# Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - -# Find dim range bounds based on rotations -def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def yarn_linear_ramp_mask(min, max, dim): - if min == max: 
- max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func + # # Use position_ids to slice the precomputed caches + # cos = self.cos_cached[position_ids] + # sin = self.sin_cached[position_ids] + # return cos.to(x.dtype), sin.to(x.dtype) class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): @@ -277,16 +185,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
class QEffDeepseekV3Attention(nn.Module):
    """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling."""

    def __qeff_init__(
        self,
    ):
        """Derive per-head projection parameters from ``q_b_proj`` and ``kv_b_proj``.

        Reads ``self.q_b_proj.weight`` and ``self.kv_b_proj.weight`` and registers
        the following parameters (H = ``num_heads``):

        * ``q_up``     : (1, in_q,  H * qk_nope_head_dim)
        * ``q_rope``   : (1, in_q,  H * qk_rope_head_dim)
        * ``k_up``     : (1, in_kv, H * qk_nope_head_dim)
        * ``v_up``     : (1, in_kv, H * v_head_dim)
        * ``per_head_q_up``        : (1, H, in_q,  qk_nope_head_dim)
        * ``per_head_k_up``        : (1, H, qk_nope_head_dim, in_kv)
        * ``per_head_k_up_normal`` : (1, H, in_kv, qk_nope_head_dim)
        * ``per_head_v_up``        : (1, H, in_kv, v_head_dim)
        * ``fusedqk``  : (-1, H, q_lora_rank, kv_lora_rank), the per-head product
          ``per_head_q_up @ per_head_k_up``

        ``in_q`` / ``in_kv`` are the input widths of the two linear layers
        (presumably ``q_lora_rank`` / ``kv_lora_rank``, as the ``fusedqk``
        reshape implies - confirm).
        """
        # Split the query up-projection into its no-rope and rope parts, per head.
        q_up, q_rope = self.q_b_proj.weight.T.view(
            -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
        ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0)
        self.q_up = torch.nn.Parameter(q_up.detach().clone())
        q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0)
        self.q_rope = torch.nn.Parameter(q_rope.detach().clone())

        # Split the KV up-projection into its key (no-rope) and value parts, per head.
        k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0)
        v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0)

        # Clone like the other derived parameters: reshape() may return a view,
        # and the new parameters must not alias kv_b_proj's weight storage.
        self.k_up = torch.nn.Parameter(k_up.detach().clone())
        self.v_up = torch.nn.Parameter(v_up.detach().clone())

        per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
        per_head_k_up = (
            self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2)
        )
        # The value projection is laid out per head with width v_head_dim,
        # which need not equal qk_nope_head_dim.
        per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
        self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone())
        self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone())
        self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone())
        per_head_k_up_normal = self.per_head_k_up.transpose(2, 3)
        self.per_head_k_up_normal = torch.nn.Parameter(per_head_k_up_normal.detach().clone())

        # Precompute the per-head product of the query and key up-projections.
        fusedqk = torch.bmm(per_head_q_up, per_head_k_up).reshape(
            -1, self.num_heads, self.q_lora_rank, self.kv_lora_rank
        )
        self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone())
self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") - - ### select top-k experts - if self.topk_method == "noaux_tc": - assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") + def fused_forward_h_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = 
topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) + + cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + q_a_proj_out=q_a_proj_out, + fusedqk=self.fusedqk, + q_nope=q_nope, + q_pe=q_pe, + kva=kva, + k_pe=k_pe, + per_head_q_up=self.per_head_q_up, + per_head_k_up=self.per_head_k_up, + per_head_v_up=self.per_head_v_up, + per_head_k_up_normal=self.per_head_k_up_normal, + attention_mask=attention_mask, + scaling=self.softmax_scale, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, + ) - return topk_idx, topk_weight + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return 
attn_output, attn_weights, compressed_kvs -class DeepseekV3MoE(nn.Module): - """ - A mixed expert module containing shared experts. - """ + def fused_forward_kv_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - def __init__(self, config): - super().__init__() - self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + 
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + + ## Write Only + if compressed_kvs is not None: + compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) + + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - for i in range(config.n_routed_experts) - ] - ) - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) + enable_absorption = False - def forward(self, hidden_states): - identity = hidden_states - orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y - - @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = 
tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather - tokens_per_expert = tokens_per_expert.cpu().numpy() - - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens - if num_tokens == 0: - continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx - - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens - - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) + if 
enable_absorption: + if absorb_online: + qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe), dim=-1) + query = qkupTrope_nope + else: + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + qnope_rope = torch.cat((q_nope, q_pe), dim=-1) + query = qnope_rope + + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + query=query, + per_head_k_up_normal=self.per_head_k_up_normal, + per_head_v_up=self.per_head_v_up, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + compressed_kvs=compressed_kvs, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, ) - return final_out + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return attn_output, None, compressed_kvs + def fused_forward_orig( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + print("using orig forward") -# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 -class DeepseekV3Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" + # ---- KV compression ---- + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " - "when creating this class." - ) + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads + # ---- Q projections ---- + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - self.is_causal = True + kva = self.kv_a_layernorm(kva) - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) + + # ---- MLA absorption flags ---- + if mla_absorption is not None: + enable_absorption = mla_absorption.get("enable", False) + absorb_online = mla_absorption.get("online", False) else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) - - self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, - bias=config.attention_bias, - ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) - self.kv_b_proj = nn.Linear( - config.kv_lora_rank, - 
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - bias=False, + enable_absorption = False + + head_block_size = kva.shape[1] + p = self.num_heads // head_block_size + seq_kv = kva.shape[2] + + # ---- Rotary ---- + cos, sin = self.rotary_emb(q_pe, seq_len=32 * 1024) # Doesn't need q_pe as head_dim is initialized + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + + kva_expanded = ( + kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) ) - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=config.attention_bias, + k_pe_expanded = ( + k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim) ) - self._init_rope() - - self.softmax_scale = self.q_head_dim ** (-0.5) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = DeepseekV3RotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( - self.qk_rope_head_dim, - 
max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "yarn": - kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rotary_emb = DeepseekV3YarnRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **kwargs, - ) + + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) + value_states = torch.matmul(kva_expanded, v_up_per_head) + + if enable_absorption: + if absorb_online: + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + q_nope_compressed = torch.matmul( + q_a_proj_out.unsqueeze(1), + self.fusedqk, + ) + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + else: + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() + k_up_per_head = ( + self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) + ) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) - def forward( + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, + 
torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype), + attn_weights, + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + ## Do v_proj here + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, compressed_kvs + + def fused_forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + if os.environ.get("KIMI_BLOCKING", "0") == "h": + return self.fused_forward_h_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + elif os.environ.get("KIMI_BLOCKING", "0") == "kv": + return self.fused_forward_kv_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + else: + return self.fused_forward_orig( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, ) - bsz, q_len, _ = hidden_states.size() + def forward_full_kv( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = 
torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kva = compressed_kv[:, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, self.kv_lora_rank :] + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + self.kv_b_proj(self.kv_a_layernorm(kva)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": 
sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - 
attn_weights = None - return attn_output, attn_weights, past_key_value - -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 -class DeepseekV3FlashAttention2(DeepseekV3Attention): - """ - DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( + def forward_full_kv_h_blocking( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # DeepseekV3FlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - output_attentions = False - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, 
-1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + self.kv_b_proj(self.kv_a_layernorm(kva)) + .view( + bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) # TODO : split this matmul #with k_up and v_up .transpose(1, 2) ) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - if self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, 
cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DeepseekV3RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - elif torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.softmax_scale, + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + past_key_value=past_key_value, + blocking_config=blocking_config, + batch_index=batch_index, + position_ids=position_ids, ) - if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value - def _flash_attention_forward( + def forward( self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. 
- dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if os.environ.get("KIMI_BLOCKING", "0") == "h": + return self.forward_full_kv_h_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + batch_index, + output_attentions, + use_cache, + cache_position, + **kwargs, ) - - attn_output = 
pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, + return self.forward_full_kv( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + batch_index, + output_attentions, + use_cache, + cache_position, + **kwargs, ) - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, +class QEffDeepseekV3MoE(nn.Module): + def __qeff_init__( + self, + ): + self.all_gate_proj = torch.nn.Parameter( + torch.cat( + [exp.gate_proj.compressor.decompress_module(exp.gate_proj).T.unsqueeze(0) for exp in self.experts], + dim=0, + ) ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, + self.all_up_proj = torch.nn.Parameter( + torch.cat( + [exp.up_proj.compressor.decompress_module(exp.up_proj).T.unsqueeze(0) for exp in self.experts], dim=0 + ) ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), - indices_k, + self.all_down_proj = torch.nn.Parameter( + torch.cat( + [exp.down_proj.compressor.decompress_module(exp.down_proj).T.unsqueeze(0) for exp in self.experts], + dim=0, ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. 
- indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + ) + self.act_fn = self.experts[0].act_fn - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + def moe( + self, + hidden_states: torch.Tensor, + topk_indices: torch.Tensor, + topk_weights: torch.Tensor, + ): + seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + + gate_proj = self.all_gate_proj[topk_indices.flatten()] + up_proj = self.all_up_proj[topk_indices.flatten()] + down_proj = self.all_down_proj[topk_indices.flatten()] + expert_in = ( + hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.config.hidden_size) ) + gate_out = torch.bmm(expert_in, gate_proj) + up_out = torch.bmm(expert_in, up_proj) + hidden = self.act_fn(gate_out) * up_out + expert_output = torch.bmm(hidden, down_proj) + experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) + experts_out = experts_out * topk_weights.unsqueeze(-1) + final_hidden_states = torch.einsum("abc->ac", experts_out) -ATTENTION_CLASSES = { - "eager": DeepseekV3Attention, - "flash_attention_2": DeepseekV3FlashAttention2, -} + return final_hidden_states.type(hidden_states.dtype) + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return 
hidden_states -class DeepseekV3DecoderLayer(nn.Module): - def __init__(self, config: DeepseekV3Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) +class QEffPrefillOnlyDeepseekV3MoE(nn.Module): + def __qeff_init__( + self, + ): + for exp in self.experts: + gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + + gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) + up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) + down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) + + setattr(exp, "gate_proj", gate_proj) + setattr(exp, "up_proj", up_proj) + setattr(exp, "down_proj", down_proj) + + def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + for expert_idx in range(num_experts): + expert = self.experts[expert_idx] + gate_out = expert.gate_proj(hidden_states) + up_out = expert.up_proj(hidden_states) + hidden = expert.act_fn(gate_out) * up_out + expert_output = expert.down_proj(hidden) + current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) + final_hidden_states += current_hidden_states + + print("\n\ninside prefill only moe\n") + return final_hidden_states.type(hidden_states.dtype) + + def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! 
I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + """ + Forward pass of MoE block. 
+ """ + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) + mask.scatter_(1, topk_indices, topk_weights) + if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: + hidden_states = self.moe_blocked_weights_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: + hidden_states = self.moe_blocked_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + else: + hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) + + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) - self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + +class QEffDeepseekV3DecoderLayer(nn.Module): + """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: 
Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + enable_mla: Optional[bool] = False, + mla_absorption: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) + orig_hidden_states = self.input_layernorm(hidden_states) + if enable_mla: + hidden_states, self_attn_weights, present_compressed_kvs = self.self_attn.fused_forward( + hidden_states=orig_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_key_value, + compressed_kvs=compressed_kvs, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + mla_absorption=mla_absorption, + **kwargs, + ) + else: + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=orig_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = residual + hidden_states - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + if enable_mla: + outputs += (present_compressed_kvs,) + else: + outputs += (present_key_value,) return outputs -DeepseekV3_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`DeepseekV3Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3PreTrainedModel(PreTrainedModel): - config_class = DeepseekV3Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DeepseekV3DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -DeepseekV3_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. 
- - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. 
- - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3Model(DeepseekV3PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV3DecoderLayer`] - - Args: - config: DeepseekV3Config - """ - - def __init__(self, config: DeepseekV3Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] +class QEffDeepseekV3Model(nn.Module): + """Adapted DeepseekV3Model with batch_index and QEff rotary embedding.""" + + def __qeff_init__(self): + scaling_factor = self.config.rope_scaling["factor"] + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.config.qk_rope_head_dim, + max_position_embeddings=32 * 1024, + scaling_factor=scaling_factor, + base=self.config.rope_theta, + **kwargs, ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + compressed_kvs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: 
Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in 
attention_mask) else None + if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + enable_mla = getattr(self, "enable_mla", False) + + if enable_mla: + compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs) + target_len = compressed_kvs.layers[0].ckv.shape[-2] else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + target_len = past_key_values[0][0].shape[2] + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - # embed positions + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_len) hidden_states = inputs_embeds + position_embeddings = None - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None @@ -1307,32 +1046,35 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, + compressed_kvs=compressed_kvs, past_key_value=past_key_values, + batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + enable_mla=getattr(self, "enable_mla", False), + mla_absorption=getattr(self, "mla_absorption_config", None), + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = 
self.norm(hidden_states) - - # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if use_cache else None + next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1341,110 +1083,70 @@ def forward( ) -class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = DeepseekV3Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder +class QEffDeepseekV3ForCausalLM(nn.Module): + """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations.""" - def get_decoder(self): - return self.model + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {self.model.layers[0].__class__} - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + compressed_kvs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - - >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" 
- >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + compressed_kvs=compressed_kvs, past_key_values=past_key_values, + batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() loss = None if labels is not None: - # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) + shift_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not 
return_dict: @@ -1458,195 +1160,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). - - [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = DeepseekV3Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( - logits.device - ) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, 
labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py deleted file mode 100644 index f9566a491d..0000000000 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_orig.py +++ /dev/null @@ -1,1653 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ---------------------------------------------------------------------------- - -import math -import warnings -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ( - ALL_LAYERNORM_LAYERS, - is_torch_greater_or_equal_than_1_13, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from 
transformers.utils.import_utils import is_torch_fx_available - -from .configuration_deepseek import DeepseekV3Config - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "DeepseekV3Config" - - -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -class DeepseekV3RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - DeepseekV3RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) - - -class DeepseekV3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - 
self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - self.max_seq_len_cached = None - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq.to(t.device)) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 -class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): - """DeepseekV3RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 -class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): - """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -# Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - -# Find dim range bounds based on rotations -def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def yarn_linear_ramp_mask(min, max, dim): - if min == max: 
- max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - original_max_position_embeddings=4096, - beta_fast=32, - beta_slow=1, - mscale=1, - mscale_all_dim=0, - ): - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale = mscale - self.mscale_all_dim = mscale_all_dim - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - dim = self.dim - - freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - freq_inter = 1.0 / ( - self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(seq_len, device=device, dtype=torch.float32) - - freqs = torch.outer(t, inv_freq) - - _mscale = float( - yarn_get_mscale(self.scaling_factor, self.mscale) - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) - self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) - - -# Copied from 
transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class DeepseekV3MLP(nn.Module): - def __init__(self, config, hidden_size=None, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size - - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class MoEGate(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method - self.n_group = config.n_group - self.topk_group = config.topk_group - - # topk selection algorithm - self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) - if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) - self.reset_parameters() - - def reset_parameters(self) -> None: - 
import torch.nn.init as init - - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - - def forward(self, hidden_states): - bsz, seq_len, h = hidden_states.shape - ### compute gating score - hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") - - ### select top-k experts - if self.topk_method == "noaux_tc": - assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") - - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor - - return topk_idx, topk_weight - - -class DeepseekV3MoE(nn.Module): - """ - A mixed expert module containing shared experts. 
- """ - - def __init__(self, config): - super().__init__() - self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) - else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - for i in range(config.n_routed_experts) - ] - ) - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) - - def forward(self, hidden_states): - identity = hidden_states - orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - # flat_topk_idx = topk_idx.view(-1) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y - - @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = 
tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather - tokens_per_expert = tokens_per_expert.cpu().numpy() - - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens - if num_tokens == 0: - continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx - - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens - - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - 
.mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - return final_out - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 -class DeepseekV3Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - - self.is_causal = True - - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) - - self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, - bias=config.attention_bias, - ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) - self.kv_b_proj = nn.Linear( - config.kv_lora_rank, - self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - bias=False, - ) - - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=config.attention_bias, - ) - self._init_rope() - - self.softmax_scale = self.q_head_dim ** (-0.5) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = DeepseekV3RotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - 
base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "yarn": - kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rotary_emb = DeepseekV3YarnRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **kwargs, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - 
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 -class DeepseekV3FlashAttention2(DeepseekV3Attention): - """ - DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # DeepseekV3FlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - - if self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - 
self.v_head_dim]) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DeepseekV3RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - elif torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.softmax_scale, - ) - if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. 
- causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. 
- indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -ATTENTION_CLASSES = { - "eager": DeepseekV3Attention, - "flash_attention_2": DeepseekV3FlashAttention2, -} - - -class DeepseekV3DecoderLayer(nn.Module): - def __init__(self, config: DeepseekV3Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) - self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. 
- output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -DeepseekV3_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. 
- - Parameters: - config ([`DeepseekV3Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3PreTrainedModel(PreTrainedModel): - config_class = DeepseekV3Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DeepseekV3DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -DeepseekV3_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3Model(DeepseekV3PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV3DecoderLayer`] - - Args: - config: DeepseekV3Config - """ - - def __init__(self, config: DeepseekV3Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and 
inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += 
(layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = DeepseekV3Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: 
Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - - >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = 
past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - 
for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). - - [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - DeepseekV3_START_DOCSTRING, -) -class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = DeepseekV3Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - 
return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( - logits.device - ) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - 
loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py deleted file mode 100644 index c4688edfee..0000000000 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py +++ /dev/null @@ -1,1162 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ---------------------------------------------------------------------------- - -import math -import os -from typing import Dict, List, Optional, Tuple, Type, Union - -import torch -import torch.nn.functional as F -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - -from QEfficient.blocking.attention_blocking import ( - AttentionBlockingConfig, - generic_blocked_attention_interface, - generic_blocked_mla_attention_interface, -) -from QEfficient.customop.rms_norm import CustomRMSNormFunc -from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache -from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - -# Find dim range bounds based on rotations -def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def yarn_linear_ramp_mask(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, 
dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -class QEffDeepseekV3CustomRMSNormAIC(nn.Module): - """ - RMSNorm module that works by replacing the current module with compiler known custom-op. - """ - - def forward(self, hidden_states): - """ - Forward pass of the RMSNorm module. - - Args: - hidden_states (torch.Tensor): Input tensor to be normalized. - - Returns: - torch.Tensor: Normalized tensor. - """ - return CustomRMSNormFunc.apply( - hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps - ) - - -class DeepseekV3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - self.max_seq_len_cached = None - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq.to(t.device)) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - # def forward(self, x, position_ids): - # seq_len = torch.max(position_ids) + 1 - # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - # # Use position_ids to slice the precomputed caches - # cos = self.cos_cached[position_ids] - # sin = self.sin_cached[position_ids] - # return cos.to(x.dtype), sin.to(x.dtype) - - -class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - original_max_position_embeddings=4096, - beta_fast=32, - beta_slow=1, - mscale=1, - mscale_all_dim=0, - ): - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale = mscale - self.mscale_all_dim = mscale_all_dim - 
super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - dim = self.dim - - freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - freq_inter = 1.0 / ( - self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(seq_len, device=device, dtype=torch.float32) - - freqs = torch.outer(t, inv_freq) - - _mscale = float( - yarn_get_mscale(self.scaling_factor, self.mscale) - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) - self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class QEffDeepseekV3Attention(nn.Module): - """Adapted DeepseekV3Attention with QEff logic, adding batch_index and proper position_ids handling.""" - - def __qeff_init__( - self, - ): - q_up, q_rope = self.q_b_proj.weight.T.view( - -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim - ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - self.q_up = torch.nn.Parameter(q_up.detach().clone()) - - q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0) - self.q_rope = torch.nn.Parameter(q_rope.detach().clone()) - - k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, 
self.v_head_dim], dim=-1 - ) - k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0) - v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0) - - self.k_up = torch.nn.Parameter(k_up.detach()) - self.v_up = torch.nn.Parameter(v_up.detach()) - per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - per_head_k_up = ( - self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) - ) - per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) - self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) - self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) - per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) - self.per_head_k_up_normal = torch.nn.Parameter(per_head_k_up_normal.detach().clone()) - - fusedqk = torch.bmm(per_head_q_up, per_head_k_up).reshape( - -1, self.num_heads, self.q_lora_rank, self.kv_lora_rank - ) - self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) - - def fused_forward_h_blocking( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - mla_absorption: Optional[Dict[str, bool]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = 
compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - - kva = compressed_kv[:, :, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - - q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.matmul(q_a_proj_out, self.q_rope) - q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - - kva = self.kv_a_layernorm(kva) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if compressed_kvs is not None: - kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - - cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - - attn_output, attn_weights = generic_blocked_mla_attention_interface( - module=self, - q_a_proj_out=q_a_proj_out, - fusedqk=self.fusedqk, - q_nope=q_nope, - q_pe=q_pe, - kva=kva, - k_pe=k_pe, - per_head_q_up=self.per_head_q_up, - per_head_k_up=self.per_head_k_up, - per_head_v_up=self.per_head_v_up, - per_head_k_up_normal=self.per_head_k_up_normal, - attention_mask=attention_mask, - scaling=self.softmax_scale, - mla_absorption=mla_absorption, - blocking_config=blocking_config, - position_ids=position_ids, - **kwargs, - ) - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, compressed_kvs - - def fused_forward_kv_blocking( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - 
past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - mla_absorption: Optional[Dict[str, bool]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - - kva = compressed_kv[:, :, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - - q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - q_pe = torch.matmul(q_a_proj_out, self.q_rope) - q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) - - kva = self.kv_a_layernorm(kva) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - - ## Write Only - if compressed_kvs is not None: - compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) - - cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - if compressed_kvs is not None: - compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) - - if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) - else: - enable_absorption = False - - if enable_absorption: - if absorb_online: - qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) - dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) - else: - dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) - qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe), dim=-1) - query = qkupTrope_nope - else: - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, 
q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - qnope_rope = torch.cat((q_nope, q_pe), dim=-1) - query = qnope_rope - - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - - attn_output, attn_weights = generic_blocked_mla_attention_interface( - module=self, - query=query, - per_head_k_up_normal=self.per_head_k_up_normal, - per_head_v_up=self.per_head_v_up, - attention_mask=attention_mask, - scaling=self.softmax_scale, - layer_idx=self.layer_idx, - compressed_kvs=compressed_kvs, - mla_absorption=mla_absorption, - blocking_config=blocking_config, - position_ids=position_ids, - **kwargs, - ) - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None, compressed_kvs - - def fused_forward_orig( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - mla_absorption: Optional[Dict[str, bool]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - print("using orig forward") - - # ---- KV compression ---- - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - - kva = compressed_kv[:, :, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - - # ---- Q projections ---- - q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) - - q_pe = torch.bmm(q_a_proj_out, self.q_rope) - q_pe = q_pe.view(bsz, q_len, self.num_heads, 
self.qk_rope_head_dim).transpose(1, 2) - - kva = self.kv_a_layernorm(kva) - - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if compressed_kvs is not None: - kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) - - # ---- MLA absorption flags ---- - if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) - else: - enable_absorption = False - - head_block_size = kva.shape[1] - p = self.num_heads // head_block_size - seq_kv = kva.shape[2] - - # ---- Rotary ---- - cos, sin = self.rotary_emb(q_pe, seq_len=32 * 1024) # Doesn't need q_pe as head_dim is initialized - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) - - kva_expanded = ( - kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) - ) - - k_pe_expanded = ( - k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim) - ) - - v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) - value_states = torch.matmul(kva_expanded, v_up_per_head) - - if enable_absorption: - if absorb_online: - out = torch.matmul(self.per_head_q_up, self.per_head_k_up) - q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) - else: - q_nope_compressed = torch.matmul( - q_a_proj_out.unsqueeze(1), - self.fusedqk, - ) - query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) - key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) - else: - q_nope = torch.bmm(q_a_proj_out, self.q_up) - q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) - query_states = torch.cat((q_nope, q_pe), dim=-1) - - k_up_per_head = ( - self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, 
self.qk_nope_head_dim).permute(1, 0, 2) - ) - k_nope = torch.matmul(kva_expanded, k_up_per_head) - key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, - torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype), - attn_weights, - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) - ## Do v_proj here - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, compressed_kvs - - def fused_forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - mla_absorption: Optional[Dict[str, bool]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.environ.get("KIMI_BLOCKING", "0") == "h": - return self.fused_forward_h_blocking( - hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - compressed_kvs, - batch_index, - output_attentions, - use_cache, - cache_position, - mla_absorption, - **kwargs, - ) - elif os.environ.get("KIMI_BLOCKING", "0") == "kv": - return self.fused_forward_kv_blocking( - hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - compressed_kvs, - batch_index, - output_attentions, - use_cache, - 
cache_position, - mla_absorption, - **kwargs, - ) - else: - return self.fused_forward_orig( - hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - compressed_kvs, - batch_index, - output_attentions, - use_cache, - cache_position, - mla_absorption, - **kwargs, - ) - - def forward_full_kv( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - - q_nope = q[:, :, :, : self.qk_nope_head_dim] - q_pe = q[:, :, :, self.qk_nope_head_dim :] - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - - kva = compressed_kv[:, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, self.kv_lora_rank :] - - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(kva)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope = kv[:, :, :, : self.qk_nope_head_dim] - value_states = kv[:, :, :, self.qk_nope_head_dim :] - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) - key_states = torch.cat((k_nope, k_pe_new), -1) - - if 
past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value - - def forward_full_kv_h_blocking( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - - q_nope = q[:, :, :, : self.qk_nope_head_dim] - q_pe = q[:, :, :, self.qk_nope_head_dim :] - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv = compressed_kv.view(bsz, q_len, -1, 
self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) - - kva = compressed_kv[:, :, :, : self.kv_lora_rank] - k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] - - kv = ( - self.kv_b_proj(self.kv_a_layernorm(kva)) - .view( - bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) # TODO : split this matmul #with k_up and v_up - .transpose(1, 2) - ) - - k_nope = kv[:, :, :, : self.qk_nope_head_dim] - value_states = kv[:, :, :, self.qk_nope_head_dim :] - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) - q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = torch.cat((q_nope, q_pe), -1) - k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) - key_states = torch.cat((k_nope, k_pe_new), -1) - - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - - attn_output, attn_weights = generic_blocked_attention_interface( - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - scaling=self.softmax_scale, - layer_idx=self.layer_idx, - past_key_value=past_key_value, - blocking_config=blocking_config, - batch_index=batch_index, - position_ids=position_ids, - ) - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.environ.get("KIMI_BLOCKING", "0") == "h": - return self.forward_full_kv_h_blocking( - 
hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - batch_index, - output_attentions, - use_cache, - cache_position, - **kwargs, - ) - else: - return self.forward_full_kv( - hidden_states, - position_embeddings, - attention_mask, - position_ids, - past_key_value, - batch_index, - output_attentions, - use_cache, - cache_position, - **kwargs, - ) - - -class QEffDeepseekV3MoE(nn.Module): - def __qeff_init__( - self, - ): - self.all_gate_proj = torch.nn.Parameter( - torch.cat( - [exp.gate_proj.compressor.decompress_module(exp.gate_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) - ) - self.all_up_proj = torch.nn.Parameter( - torch.cat( - [exp.up_proj.compressor.decompress_module(exp.up_proj).T.unsqueeze(0) for exp in self.experts], dim=0 - ) - ) - self.all_down_proj = torch.nn.Parameter( - torch.cat( - [exp.down_proj.compressor.decompress_module(exp.down_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) - ) - self.act_fn = self.experts[0].act_fn - - def moe( - self, - hidden_states: torch.Tensor, - topk_indices: torch.Tensor, - topk_weights: torch.Tensor, - ): - seq_len, _ = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - - gate_proj = self.all_gate_proj[topk_indices.flatten()] - up_proj = self.all_up_proj[topk_indices.flatten()] - down_proj = self.all_down_proj[topk_indices.flatten()] - expert_in = ( - hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.config.hidden_size) - ) - gate_out = torch.bmm(expert_in, gate_proj) - up_out = torch.bmm(expert_in, up_proj) - hidden = self.act_fn(gate_out) * up_out - expert_output = torch.bmm(hidden, down_proj) - experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) - experts_out = experts_out * topk_weights.unsqueeze(-1) - - final_hidden_states = torch.einsum("abc->ac", experts_out) - 
- return final_hidden_states.type(hidden_states.dtype) - - def forward(self, hidden_states): - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - -class QEffPrefillOnlyDeepseekV3MoE(nn.Module): - def __qeff_init__( - self, - ): - for exp in self.experts: - gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) - - gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) - up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) - down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) - - setattr(exp, "gate_proj", gate_proj) - setattr(exp, "up_proj", up_proj) - setattr(exp, "down_proj", down_proj) - - def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - for expert_idx in range(num_experts): - expert = self.experts[expert_idx] - gate_out = expert.gate_proj(hidden_states) - up_out = expert.up_proj(hidden_states) - hidden = expert.act_fn(gate_out) * up_out - expert_output = expert.down_proj(hidden) - current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) - final_hidden_states += current_hidden_states - - print("\n\ninside prefill only moe\n") - return final_hidden_states.type(hidden_states.dtype) - - def orig_moe(self, hidden_states: 
torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - r""" - CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused - to not have to do a loop here (deepseek has 256 experts soooo yeah). - """ - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) - expert_mask = expert_mask.permute(2, 0, 1) - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) - - if token_indices.numel() > 0: - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, weighted_output) - - # in original deepseek, the output of the experts are gathered once we leave this module - # thus the moe module is itelsf an IsolatedParallel module - # and all expert are "local" meaning we shard but we don't gather - return final_hidden_states.type(hidden_states.dtype) - - def forward(self, hidden_states): - """ - Forward pass of MoE block. 
- """ - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) - mask.scatter_(1, topk_indices, topk_weights) - if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: - hidden_states = self.moe_blocked_weights_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: - hidden_states = self.moe_blocked_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - else: - hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) - - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - -class QEffDeepseekV3DecoderLayer(nn.Module): - """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - compressed_kvs: Optional[torch.Tensor] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - enable_mla: Optional[bool] = False, - mla_absorption: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - orig_hidden_states = self.input_layernorm(hidden_states) - if enable_mla: - 
hidden_states, self_attn_weights, present_compressed_kvs = self.self_attn.fused_forward( - hidden_states=orig_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - position_embeddings=position_embeddings, - past_key_value=past_key_value, - compressed_kvs=compressed_kvs, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - mla_absorption=mla_absorption, - **kwargs, - ) - else: - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=orig_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - position_embeddings=position_embeddings, - past_key_value=past_key_value, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - if use_cache: - if enable_mla: - outputs += (present_compressed_kvs,) - else: - outputs += (present_key_value,) - - return outputs - - -class QEffDeepseekV3Model(nn.Module): - """Adapted DeepseekV3Model with batch_index and QEff rotary embedding.""" - - def __qeff_init__(self): - scaling_factor = self.config.rope_scaling["factor"] - kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rotary_emb = DeepseekV3YarnRotaryEmbedding( - self.config.qk_rope_head_dim, - max_position_embeddings=32 * 1024, - scaling_factor=scaling_factor, - base=self.config.rope_theta, - **kwargs, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - compressed_kvs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - batch_index: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - enable_mla = getattr(self, "enable_mla", False) - - if enable_mla: - compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs) - target_len = compressed_kvs.layers[0].ckv.shape[-2] - else: - target_len = past_key_values[0][0].shape[2] - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = 
_create_causal_mask(position_ids=position_ids, target_length=target_len) - hidden_states = inputs_embeds - position_embeddings = None - - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - compressed_kvs=compressed_kvs, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - enable_mla=getattr(self, "enable_mla", False), - mla_absorption=getattr(self, "mla_absorption_config", None), - **kwargs, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class QEffDeepseekV3ForCausalLM(nn.Module): - """Adapted DeepseekV3ForCausalLM with batch_index and QEff optimizations.""" - - def get_submodules_for_export(self) -> Type[nn.Module]: - """ - Return the set of class used as the repeated layer across the model for subfunction extraction. - Notes: - This method should return the *class object* (not an instance). - Downstream code can use this to find/build subfunctions for repeated blocks. 
- """ - return {self.model.layers[0].__class__} - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - compressed_kvs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - batch_index: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - compressed_kvs=compressed_kvs, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states).float() - - loss = None - if labels is not None: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, 
self.config.vocab_size) - shift_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2cdb107956..559d7e86ed 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -249,7 +249,7 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) -from QEfficient.transformers.models.deepseek_v3.modeling_deepseek_qeff import ( +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek import ( QEffDeepseekV3Attention, QEffDeepseekV3CustomRMSNormAIC, QEffDeepseekV3DecoderLayer, diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py deleted file mode 100644 index 554a9a6c55..0000000000 --- a/examples/kimi_k2/export_kimik2.py +++ /dev/null @@ -1,59 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ---------------------------------------------------------------------------- - -import os - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM - -# parameters to be configured - -prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" when using MLA and this should be equal to TS in that case. 
-num_hidden_layers = 2 -TS = 4 -enable_mla = True -mla_absorption_config = {"enable": False, "online": False} -qaic_config = None - -if os.environ.get("KIMI_BLOCKING", "0") == "kv": - qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} - num_kv_heads_repeat = 1 - -if os.environ.get("KIMI_BLOCKING", "0") == "h": - if enable_mla: - num_kv_heads_repeat = TS - else: - num_kv_heads_repeat = 1 # head replication is not needed if MLA is not enabled - qaic_config = {"enable_blocking": True, "blocking_mode": "h", "head_block_size": TS} - -model = AutoModelForCausalLM.from_pretrained( - "moonshotai/Kimi-K2-Thinking", - torch_dtype=torch.float32, - num_hidden_layers=num_hidden_layers, - trust_remote_code=True, -) -tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) - -qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat, qaic_config=qaic_config) - -qpc_path = qeff_model.compile( - prefill_seq_len=1, - ctx_len=16 * 1024, - enable_mla=enable_mla, - mla_absorption_config=mla_absorption_config, - mxfp6_matmul=True, - mxint8_kv_cache=False, - num_devices=TS, - num_cores=16, - # use_onnx_subfunctions=True, - qaic_config=qaic_config, -) - -qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index 6724802e83..7fb2d41c4b 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -13,6 +13,8 @@ from QEfficient import QEFFAutoModelForCausalLM +# parameters to be configured + prompt = "Once upon a time," num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" when using MLA and this should be equal to TS in that case. 
num_hidden_layers = 2 diff --git a/examples/kimi_k2/run_orig_kimi_k2.py b/examples/kimi_k2/run_orig_kimi_k2.py deleted file mode 100644 index 03052dfc31..0000000000 --- a/examples/kimi_k2/run_orig_kimi_k2.py +++ /dev/null @@ -1,36 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ---------------------------------------------------------------------------- - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -model = AutoModelForCausalLM.from_pretrained( - "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd", - torch_dtype=torch.float32, - num_hidden_layers=2, - trust_remote_code=True, -) -tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) - -prompt = "Once upon a time," -inputs = tokenizer(prompt, return_tensors="pt").to(model.device) -with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=10, - do_sample=False, - use_cache=False, - ) - -response = tokenizer.decode(outputs[0], skip_special_tokens=True) -print(response) - -""" -Original Pytorch, kimi-k2 thinking: -Prompt: Once upon a time, -Completion : ?? branchesrupt??? 
flushedakislottery rehearsallesi -""" From 5c23f53590575d5fadd7c2c8242f6fa7d0c65360 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Sun, 26 Apr 2026 14:25:20 +0530 Subject: [PATCH 36/51] remove env variables Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 20 +++++-- QEfficient/blocking/blocking_configurator.py | 8 ++- .../models/deepseek_v3/modeling_deepseek.py | 21 ++----- .../transformers/models/modeling_auto.py | 10 +--- .../transformers/models/pytorch_transforms.py | 14 ++--- examples/kimi_k2/export_kimik2.py | 51 +++++++++++++++++ examples/kimi_k2/run_kimik2.py | 55 +++++-------------- 7 files changed, 102 insertions(+), 77 deletions(-) create mode 100644 examples/kimi_k2/export_kimik2.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3570857c35..2777dac54a 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -30,6 +30,7 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.models.pytorch_transforms import ( BlockingAttentionTransform, + ReplicateKVHeadTransform, ) from QEfficient.utils import ( constants, @@ -451,11 +452,22 @@ def transform( qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - if getattr(self.model, "config", None) or getattr(self.model.model, "config", None): + model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) + + if model_config: + if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): + if qaic_config: + if qaic_config.get("blocking_mode", None) == "h": + qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) + num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( + self.model, num_kv_heads_repeat + ) + if replicate_kv_transformed: + self.hash_params["config"] = 
self.model.config.to_diff_dict() + blocking_config = build_transformer_blocking_config_for_transform( - getattr(self.model, "config", None) - if getattr(self.model, "config", None) - else getattr(self.model.model, "config", None), + model_config, ctx_len=ctx_len, seq_len=seq_len, bs=bs, diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index 54e09c205a..0e47154f58 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -23,6 +23,7 @@ FP16_BYTES = 2 DEFAULT_NUM_HEADS = 64 + def _infer_head_dim(model_config: Any, num_heads: int) -> int: head_dim = get_attr_or_key(model_config, ("attention_head_dim", "head_dim", "head_dim_per_head")) if head_dim is not None: @@ -166,8 +167,11 @@ def get_num_kv_blocks_for_mla(q_len, num_heads, ctx_len): kv = max_kv_block_size(q_len, budget_bytes, num_heads) b1 = matmul1_bytes(q_len, kv, num_heads) b2 = matmul2_bytes(q_len, kv, num_heads) + assert b1 < budget_bytes, "matmul1 is not under the budget" assert b2 < budget_bytes, "matmul2 is not under the budget" + + kv_block_size = ctx_len kv_block_size_list = block_candidates_generator(ctx_len) for i in range(len(kv_block_size_list) - 1): if kv_block_size_list[i] < kv < kv_block_size_list[i + 1]: @@ -285,8 +289,10 @@ def build_transformer_blocking_config( int(data_bytes), blocking_mode=blocking_mode, ) + if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): - attention_cfg["num_kv_blocks"] = get_num_kv_blocks_for_mla(seq_len, num_heads, ctx_len) + if "kv" in blocking_mode: + attention_cfg["num_kv_blocks"] = get_num_kv_blocks_for_mla(seq_len, num_heads, ctx_len) resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv") effective_mode = _resolve_effective_blocking_mode(attention_cfg, resolved_mode) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 
c4688edfee..06e8ca67a4 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -118,16 +118,6 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len].to(dtype=x.dtype), ) - # def forward(self, x, position_ids): - # seq_len = torch.max(position_ids) + 1 - # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - # # Use position_ids to slice the precomputed caches - # cos = self.cos_cached[position_ids] - # sin = self.sin_cached[position_ids] - # return cos.to(x.dtype), sin.to(x.dtype) - class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): def __init__( @@ -426,7 +416,6 @@ def fused_forward_orig( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - print("using orig forward") # ---- KV compression ---- compressed_kv = self.kv_a_proj_with_mqa(hidden_states) @@ -529,7 +518,8 @@ def fused_forward( mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.environ.get("KIMI_BLOCKING", "0") == "h": + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + if getattr(blocking_config, "mode", None) == "h": return self.fused_forward_h_blocking( hidden_states, position_embeddings, @@ -544,7 +534,7 @@ def fused_forward( mla_absorption, **kwargs, ) - elif os.environ.get("KIMI_BLOCKING", "0") == "kv": + elif getattr(blocking_config, "mode", None) == "kv": return self.fused_forward_kv_blocking( hidden_states, position_embeddings, @@ -654,6 +644,7 @@ def forward_full_kv_h_blocking( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: @@ 
-720,7 +711,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.environ.get("KIMI_BLOCKING", "0") == "h": + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + if getattr(blocking_config, "mode", None) == "h": return self.forward_full_kv_h_blocking( hidden_states, position_embeddings, @@ -836,7 +828,6 @@ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_ma current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) final_hidden_states += current_hidden_states - print("\n\ninside prefill only moe\n") return final_hidden_states.type(hidden_states.dtype) def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 67aa5a0094..5753f3e241 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -53,7 +53,6 @@ PrefillOnlyChunkedTransform, PrefillOnlyExternalModuleMapperTransform, PrefillOnlyTransform, - ReplicateKVHeadTransform, RevertPrefillKeepAttentionTransform, RevertPrefillOnlyExternalModuleMapperTransform, RevertPrefillOnlyTransform, @@ -2884,11 +2883,6 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached - if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) - if replicate_kv_transformed: - self.hash_params["config"] = model.config.to_diff_dict() - # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. 
The role of the sampler is to just add nodes at the output of the @@ -3220,7 +3214,9 @@ def export( if enable_mla: for lay in self.model.model.layers: if lay is not None: - num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // 576 + num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( + self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim + ) example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 559d7e86ed..3597cac041 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -814,22 +814,20 @@ def _get_text_model(model): raise AttributeError("No suitable text model found in the provided model.") @classmethod - def apply(cls, model: nn.Module, **kwargs) -> nn.Module: + def apply(cls, model: nn.Module, num_kv_heads_repeat: int = 1) -> nn.Module: """ Replicates KV heads in attention modules based on provided multiplier. Args: model: The model to apply the transform to. - kwargs: Additional arguments for the transformation. Includes: - - num_kv_heads_repeat: The number of times to repeat the KV heads. + num_kv_heads_repeat: The number of times to repeat the KV heads. 
""" - n_repeat = kwargs.pop("num_kv_heads_repeat", 1) transformed = False - if n_repeat is not None and n_repeat > 1: + if num_kv_heads_repeat is not None and num_kv_heads_repeat > 1: text_model = cls._get_text_model(model) orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads - new_kv_heads = n_repeat * orig_kv_heads + new_kv_heads = num_kv_heads_repeat * orig_kv_heads text_model.config.orig_kv_heads = orig_kv_heads text_model.config.num_key_value_heads = new_kv_heads @@ -844,7 +842,7 @@ def apply(cls, model: nn.Module, **kwargs) -> nn.Module: head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim cls._duplicate_weights_for_linear_layer( - attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size + attn.kv_a_proj_with_mqa, orig_kv_heads, num_kv_heads_repeat, head_dim, hidden_size ) return model, transformed @@ -1043,9 +1041,9 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "forward_full_kv": QEffDeepseekV3Attention.forward_full_kv, "forward_full_kv_h_blocking": QEffDeepseekV3Attention.forward_full_kv_h_blocking, "fused_forward": QEffDeepseekV3Attention.fused_forward, + "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "fused_forward_kv_blocking": QEffDeepseekV3Attention.fused_forward_kv_blocking, "fused_forward_orig": QEffDeepseekV3Attention.fused_forward_orig, - "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, "DeepseekV3RMSNorm": { diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py new file mode 100644 index 0000000000..1f624361d4 --- /dev/null +++ b/examples/kimi_k2/export_kimik2.py @@ -0,0 +1,51 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +# parameters to be configured +prompt = "Once upon a time," +num_hidden_layers = 2 +TS = 4 +enable_mla = True +mla_absorption_config = {"enable": False, "online": False} +# qaic_config = None #for orig_forward +# qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward +# qaic_config = {"enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} # for KV blocking + +# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" +model_path = ( + "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" +) +model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) + +qeff_model = QEFFAutoModelForCausalLM(model) + +prefill_seq_len = 1 +ctx_len = 16 * 1024 + +qpc_path = qeff_model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + enable_mla=enable_mla, + mla_absorption_config=mla_absorption_config, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=TS, + num_cores=16, + use_onnx_subfunctions=True, + qaic_config=qaic_config, +) + +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index 7fb2d41c4b..c33e5cd4b4 100644 --- a/examples/kimi_k2/run_kimik2.py +++ 
b/examples/kimi_k2/run_kimik2.py @@ -5,8 +5,6 @@ # # ---------------------------------------------------------------------------- -import os - import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -14,25 +12,15 @@ from QEfficient import QEFFAutoModelForCausalLM # parameters to be configured - prompt = "Once upon a time," -num_kv_heads_repeat = 1 # When using KIMI_BLOCKING="kv", make sure num_kv_heads_repeat is set to 1. Use only for KIMI_BLOCKING="h" when using MLA and this should be equal to TS in that case. num_hidden_layers = 2 TS = 4 enable_mla = True mla_absorption_config = {"enable": False, "online": False} -qaic_config = None - -if os.environ.get("KIMI_BLOCKING", "0") == "kv": - qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} - num_kv_heads_repeat = 1 - -if os.environ.get("KIMI_BLOCKING", "0") == "h": - if enable_mla: - num_kv_heads_repeat = TS - else: - num_kv_heads_repeat = 1 # head replication is not needed if MLA is not enabled - qaic_config = {"enable_blocking": True, "blocking_mode": "h", "head_block_size": TS} +# qaic_config = None #for orig_forward +# qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward +# qaic_config = {"enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} # for KV blocking # model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" model_path = ( @@ -44,7 +32,7 @@ tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) PREFILL_SEQ_LEN = 32 -CTX_LEN = 128 +CTX_LEN = 8192 generation_len = 10 generated_ids = [] @@ -57,7 +45,8 @@ # out = model(**inputs) # predictions = torch.argmax(out.logits, dim=-1) -qeff_model = QEFFAutoModelForCausalLM(model, 
num_kv_heads_repeat=num_kv_heads_repeat, qaic_config=qaic_config) +qeff_model = QEFFAutoModelForCausalLM(model) +qeff_model.transform(ctx_len=CTX_LEN, seq_len=PREFILL_SEQ_LEN, bs=1, num_devices=TS, qaic_config=qaic_config) qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) @@ -72,8 +61,12 @@ model.config.qk_nope_head_dim + model.config.qk_rope_head_dim, ) pad_shape_v = (1, model.config.num_attention_heads, CTX_LEN, model.config.v_head_dim) -pad_shape_ckv = (1, num_kv_heads_repeat, CTX_LEN, model.config.kv_lora_rank) -pad_shape_k_pe = (1, num_kv_heads_repeat, CTX_LEN, model.config.qk_rope_head_dim) + +num_heads = model.model.layers[0].self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( + model.config.kv_lora_rank + model.config.qk_rope_head_dim +) +pad_shape_ckv = (1, num_heads, CTX_LEN, model.config.kv_lora_rank) +pad_shape_k_pe = (1, num_heads, CTX_LEN, model.config.qk_rope_head_dim) past_key_values = [] compressed_kvs = [] @@ -97,9 +90,6 @@ prefill_qeff_out = qeff_model.model(**inputs) -# breakpoint() -# assert (prefill_qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 - position_ids = inputs["position_ids"] qeff_out = prefill_qeff_out qeff_generated_ids = [] @@ -124,22 +114,3 @@ print("QEFF Transformed Model Outputs (Torch CPU): \n") print("Prompt:", repr(prompt)) print("Completion:", repr(predicted_string)) - - -prefill_seq_len = 1 -ctx_len = 16 * 1024 - -qpc_path = qeff_model.compile( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - enable_mla=enable_mla, - mla_absorption_config=mla_absorption_config, - mxfp6_matmul=True, - mxint8_kv_cache=False, - num_devices=TS, - num_cores=16, - use_onnx_subfunctions=True, - qaic_config=qaic_config, -) - -qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 8e7ddbe4ed9ea6bf727bcf4386488dc6df8f66a9 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 27 Apr 2026 
01:21:20 +0530 Subject: [PATCH 37/51] simplify mla_absorption_config Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 9 +--- .../blocking/blocked_attention_forwards.py | 17 +++++--- .../models/deepseek_v3/modeling_deepseek.py | 43 +++++++++++-------- .../transformers/models/modeling_auto.py | 27 ++++++------ examples/kimi_k2/export_kimik2.py | 12 +++--- examples/kimi_k2/run_kimik2.py | 11 +++-- 6 files changed, 63 insertions(+), 56 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2777dac54a..30ee4d5696 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -287,7 +287,6 @@ def _export( """ # TODO: Hack for retain_full_kv, handle this outside export_kwargs.pop("retain_full_kv", None) - export_kwargs.pop("enable_mla", None) export_kwargs.pop("mla_absorption_config", None) onnx_path = export_dir / f"{self.model_name}.onnx" @@ -393,8 +392,7 @@ def get_onnx_path( offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, retain_full_kv: Optional[bool] = False, - enable_mla: Optional[bool] = False, - mla_absorption_config: Optional[bool] = False, + mla_absorption_config: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, **compiler_options, ): @@ -402,7 +400,6 @@ def get_onnx_path( "offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, - "enable_mla": enable_mla, "mla_absorption_config": mla_absorption_config, } @@ -501,8 +498,7 @@ def _compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, - enable_mla: Optional[bool] = False, - mla_absorption_config: Optional[Dict[str, bool]] = False, + mla_absorption_config: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, specialization_module_name: Optional[str] = None, **compiler_options, @@ -543,7 +539,6 @@ def 
_compile( offload_pt_weights, use_onnx_subfunctions, retain_full_kv, - enable_mla, mla_absorption_config, num_devices=mdp_ts_num_devices, qaic_config=qaic_config, diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 5e4ed12f50..4da283b462 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -889,9 +889,12 @@ def blocked_kv_mla_attention_forward( start_index=start_index, ) - enable_absorption = mla_absorption.get("enable", False) + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + else: + absorption = False - if enable_absorption: + if absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] @@ -963,10 +966,10 @@ def blocked_h_mla_attention_forward( masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=q_pe.device) if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) else: - enable_absorption = False + absorption = False h_output_blocks = [] h_attn_blocks = [] @@ -975,8 +978,8 @@ def blocked_h_mla_attention_forward( h_start = head_block_idx * head_block_size h_end = min(h_start + head_block_size, num_heads) - if enable_absorption: - if absorb_online: + if absorption: + if online: qup_kupT = torch.matmul(per_head_q_up[:, h_start:h_end, :, :], per_head_k_up[:, h_start:h_end, :, :]) dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) else: diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 
06e8ca67a4..1b67b3df9f 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -359,13 +359,13 @@ def fused_forward_kv_blocking( compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) else: - enable_absorption = False + absorption = False - if enable_absorption: - if absorb_online: + if absorption: + if online: qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) else: @@ -438,10 +438,10 @@ def fused_forward_orig( # ---- MLA absorption flags ---- if mla_absorption is not None: - enable_absorption = mla_absorption.get("enable", False) - absorb_online = mla_absorption.get("online", False) + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) else: - enable_absorption = False + absorption = False head_block_size = kva.shape[1] p = self.num_heads // head_block_size @@ -465,8 +465,8 @@ def fused_forward_orig( v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) value_states = torch.matmul(kva_expanded, v_up_per_head) - if enable_absorption: - if absorb_online: + if absorption: + if online: out = torch.matmul(self.per_head_q_up, self.per_head_k_up) q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) else: @@ -898,13 +898,16 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - enable_mla: Optional[bool] = False, - mla_absorption: Optional[bool] = False, + mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> 
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states orig_hidden_states = self.input_layernorm(hidden_states) - if enable_mla: + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + if cache_compressed: hidden_states, self_attn_weights, present_compressed_kvs = self.self_attn.fused_forward( hidden_states=orig_hidden_states, attention_mask=attention_mask, @@ -943,7 +946,7 @@ def forward( if output_attentions: outputs += (self_attn_weights,) if use_cache: - if enable_mla: + if cache_compressed: outputs += (present_compressed_kvs,) else: outputs += (present_key_value,) @@ -1006,9 +1009,14 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - enable_mla = getattr(self, "enable_mla", False) - if enable_mla: + mla_absorption = getattr(self, "mla_absorption_config", None) + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + + if cache_compressed: compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs) target_len = compressed_kvs.layers[0].ckv.shape[-2] else: @@ -1046,8 +1054,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - enable_mla=getattr(self, "enable_mla", False), - mla_absorption=getattr(self, "mla_absorption_config", None), + mla_absorption=mla_absorption, **kwargs, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5753f3e241..976f99bbcb 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2767,10 +2767,8 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): def mla( self, - enable_mla: Optional[bool] = False, 
- mla_absorption_config: Optional[Dict[str, bool]] = False, + mla_absorption_config: Optional[Dict[str, bool]] = None, ): - setattr(self.model.model, "enable_mla", enable_mla) setattr(self.model.model, "mla_absorption_config", mla_absorption_config) def prefill( @@ -3093,9 +3091,6 @@ def export( enable_chunking = kwargs.get("enable_chunking", False) # TODO: HACK handle better - if enable_mla := kwargs.get("enable_mla", False): - self.hash_params["enable_mla"] = enable_mla - setattr(self.model.model, "enable_mla", enable_mla) if mla_absorption_config := kwargs.get("mla_absorption_config", None): self.hash_params["mla_absorption_config"] = mla_absorption_config setattr(self.model.model, "mla_absorption_config", mla_absorption_config) @@ -3211,7 +3206,11 @@ def export( output_names.append(f"past_{kv}.{i}_RetainedState") if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - if enable_mla: + if mla_absorption_config is not None: + cache_compressed = mla_absorption_config.get("cache_compressed", False) + else: + cache_compressed = False + if cache_compressed: for lay in self.model.model.layers: if lay is not None: num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( @@ -3423,8 +3422,7 @@ def compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, - enable_mla: Optional[bool] = False, - mla_absorption_config: Optional[Dict[str, bool]] = False, + mla_absorption_config: Optional[Dict[str, bool]] = None, **compiler_options, ) -> str: """ @@ -3506,8 +3504,12 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" - if mla_absorption_config and not enable_mla: - logger.warning("mla_absorption_config will be ignored as enable_mla is set to False") + if mla_absorption_config is not None: + cache_compressed = mla_absorption_config.get("cache_compressed", False) + else: + cache_compressed = False + if mla_absorption_config is not None and not cache_compressed: + logger.warning("mla_absorption_config will be ignored as cache_compressed is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( "`kv_cache_batch_size` or `full_batch_size` is being passed" @@ -3636,7 +3638,7 @@ def compile( kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] # --- Compilation --- custom_io = {} - if not enable_mla: + if not cache_compressed: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): for kv in ["key", "value"]: @@ -3665,7 +3667,6 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, - enable_mla=enable_mla, mla_absorption_config=mla_absorption_config, **compiler_options, ) diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 1f624361d4..891cc998cd 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -14,12 +14,15 @@ prompt = "Once upon a time," num_hidden_layers = 2 TS = 4 -enable_mla = True -mla_absorption_config = {"enable": False, "online": False} +mla_absorption_config = {"cache_compressed": True, "absorption": False, "online": False} # qaic_config = None #for orig_forward # qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward -# qaic_config = {"enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat -qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + 
"enable_blocking": True, + "blocking_mode": "h", + "num_kv_heads_repeat": TS, +} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} # for KV blocking # model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" model_path = ( @@ -38,7 +41,6 @@ qpc_path = qeff_model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - enable_mla=enable_mla, mla_absorption_config=mla_absorption_config, mxfp6_matmul=True, mxint8_kv_cache=False, diff --git a/examples/kimi_k2/run_kimik2.py b/examples/kimi_k2/run_kimik2.py index c33e5cd4b4..d1c5b1abe5 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/kimi_k2/run_kimik2.py @@ -15,8 +15,7 @@ prompt = "Once upon a time," num_hidden_layers = 2 TS = 4 -enable_mla = True -mla_absorption_config = {"enable": False, "online": False} +mla_absorption_config = {"cache_compressed": False, "absorption": False, "online": False} # qaic_config = None #for orig_forward # qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward # qaic_config = {"enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat @@ -47,7 +46,7 @@ qeff_model = QEFFAutoModelForCausalLM(model) qeff_model.transform(ctx_len=CTX_LEN, seq_len=PREFILL_SEQ_LEN, bs=1, num_devices=TS, qaic_config=qaic_config) -qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config) +qeff_model.mla(mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -82,8 +81,8 @@ x = (ckv, k_pe) compressed_kvs.append(x) - -if enable_mla: +cache_compressed = 
mla_absorption_config.get("cache_compressed", False) +if cache_compressed: inputs["compressed_kvs"] = compressed_kvs else: inputs["past_key_values"] = past_key_values @@ -101,7 +100,7 @@ "input_ids": next_token_id, "position_ids": position_ids, } - if enable_mla: + if cache_compressed: decode_inputs["compressed_kvs"] = qeff_out["past_key_values"] else: decode_inputs["past_key_values"] = qeff_out["past_key_values"] From c966cc9ebdae11fb5942862afc162eac7d4c5314 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 27 Apr 2026 20:30:40 +0530 Subject: [PATCH 38/51] address review comments Signed-off-by: Mamta Singh --- .../blocking/blocked_attention_forwards.py | 2 +- QEfficient/blocking/blocking_configurator.py | 15 ++-- .../models/deepseek_v3/modeling_deepseek.py | 41 +++++++++-- .../transformers/models/modeling_auto.py | 72 +++++++++---------- .../transformers/models/pytorch_transforms.py | 1 + QEfficient/utils/constants.py | 8 ++- examples/kimi_k2/README.md | 28 ++++++++ examples/kimi_k2/export_kimik2.py | 25 ++++--- .../run_kimik2.py | 32 +++++---- 9 files changed, 140 insertions(+), 84 deletions(-) create mode 100644 examples/kimi_k2/README.md rename examples/{kimi_k2 => text_generation}/run_kimik2.py (76%) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 4da283b462..f5de76f37a 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -858,7 +858,7 @@ def blocked_kv_mla_attention_forward( skip_kv = True current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device) - ctx_len = compressed_kvs.layers[0].ckv.shape[2] + ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2] kv_block_size = -(-ctx_len // num_kv_blocks) position_ids = cache_kwargs.get("position_ids") diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index 0e47154f58..deed73a7bf 100644 --- 
a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -18,10 +18,7 @@ from QEfficient.blocking.attention_blocking import AttentionBlockingConfig, BlockingMode from QEfficient.utils import get_attr_or_key, require_value -from QEfficient.utils.constants import VTCM_SIZE_THRESHOLD - -FP16_BYTES = 2 -DEFAULT_NUM_HEADS = 64 +from QEfficient.utils.constants import DEFAULT_NUM_HEADS, FP16_BYTES, KV_LORA_RANK, ROPE_DIM, VTCM_SIZE_THRESHOLD def _infer_head_dim(model_config: Any, num_heads: int) -> int: @@ -96,14 +93,14 @@ def block_candidates_generator(max_length: int) -> List[int]: def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: """Bytes for [1,num_heads,q,kv] x [1,1,kv,512] -> [1,num_heads,q,512] in fp16.""" elems_a = num_heads * q_len * kv_block_size - elems_b = kv_block_size * 512 - elems_out = num_heads * q_len * 512 + elems_b = kv_block_size * KV_LORA_RANK + elems_out = num_heads * q_len * KV_LORA_RANK return FP16_BYTES * (elems_a + elems_b + elems_out) def matmul2_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: """Bytes for [1,num_heads,q,576] x [1,1,576,kv] -> [1,num_heads,q,kv] in fp16.""" - elems_a = num_heads * q_len * 576 + elems_a = num_heads * q_len * (KV_LORA_RANK + ROPE_DIM) elems_b = 576 * kv_block_size elems_out = num_heads * q_len * kv_block_size return FP16_BYTES * (elems_a + elems_b + elems_out) @@ -134,9 +131,9 @@ def max_kv_block_size( # B_elems = kv*512 # C_elems = num_heads*q_len*512 # Enforce A_elems + B_elems + C_elems <= max_elems - c1_elems = num_heads * q_len * 512 + c1_elems = num_heads * q_len * KV_LORA_RANK rem1 = max_elems - c1_elems - den1 = num_heads * q_len + 512 # kv coefficient from A_elems + B_elems + den1 = num_heads * q_len + KV_LORA_RANK # kv coefficient from A_elems + B_elems k1 = rem1 // den1 if rem1 >= 0 else -1 # Matmul2 elements: diff --git 
a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 1b67b3df9f..0e1f270e20 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -23,7 +23,7 @@ from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.constants import MAX_POSITION_EMBEDDINGS, MIN_MASKED_ATTENTION_VALUE def rotate_half(x): @@ -239,7 +239,7 @@ def __qeff_init__( per_head_k_up = ( self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) ) - per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) @@ -972,7 +972,7 @@ def __qeff_init__(self): } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( self.config.qk_rope_head_dim, - max_position_embeddings=32 * 1024, + max_position_embeddings=MAX_POSITION_EMBEDDINGS, scaling_factor=scaling_factor, base=self.config.rope_theta, **kwargs, @@ -992,6 +992,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions @@ -1010,7 +1011,6 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - mla_absorption = getattr(self, "mla_absorption_config", None) if mla_absorption is not None: cache_compressed = mla_absorption.get("cache_compressed", False) else: @@ -1116,6 +1116,7 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + mla_absorption = getattr(self, "mla_absorption", None) outputs = self.model( input_ids=input_ids, @@ -1130,6 +1131,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + mla_absorption=mla_absorption, **kwargs, ) @@ -1158,3 +1160,34 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + def get_dummy_pkv_cache(self, config, batch_size, seq_len): + mla_absorption = getattr(self, "mla_absorption", None) + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + + dummy_cache = [[] for _ in range(config.num_hidden_layers)] + if cache_compressed: + for layer in self.model.layers: + if layer is not None: + num_heads = layer.self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( + self.model.config.kv_lora_rank + config.qk_rope_head_dim + ) + cache_shape_1 = (batch_size, num_heads, seq_len, config.kv_lora_rank) + cache_shape_2 = (batch_size, num_heads, seq_len, config.qk_rope_head_dim) + else: + cache_shape_1 = ( + batch_size, + config.num_attention_heads, + seq_len, + config.qk_nope_head_dim + config.qk_rope_head_dim, + ) + cache_shape_2 = (batch_size, config.num_attention_heads, seq_len, config.v_head_dim) + + for i in range(config.num_hidden_layers): + dummy_cache[i].append(torch.zeros(cache_shape_1, 
dtype=config.torch_dtype)) + dummy_cache[i].append(torch.zeros(cache_shape_2, dtype=config.torch_dtype)) + + return dummy_cache diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 976f99bbcb..4820715d7e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2765,12 +2765,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [] - def mla( - self, - mla_absorption_config: Optional[Dict[str, bool]] = None, - ): - setattr(self.model.model, "mla_absorption_config", mla_absorption_config) - def prefill( self, enable: Optional[bool] = True, @@ -2878,6 +2872,10 @@ def __init__( self.ccl_enabled = False if qaic_config: self.ccl_enabled = qaic_config.get("ccl_enabled", False) + if mla_absorption := qaic_config.get("mla_absorption", None): + self.hash_params["mla_absorption"] = mla_absorption + # setattr(self.model.model, "mla_absorption", mla_absorption) + setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached @@ -3090,11 +3088,6 @@ def export( ) enable_chunking = kwargs.get("enable_chunking", False) - # TODO: HACK handle better - if mla_absorption_config := kwargs.get("mla_absorption_config", None): - self.hash_params["mla_absorption_config"] = mla_absorption_config - setattr(self.model.model, "mla_absorption_config", mla_absorption_config) - if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): if prefill_only: self.prefill(enable=True) @@ -3206,45 +3199,39 @@ def export( output_names.append(f"past_{kv}.{i}_RetainedState") if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - if mla_absorption_config is not None: - cache_compressed = mla_absorption_config.get("cache_compressed", False) + mla_absorption = 
kwargs.get("mla_absorption", None) + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) else: cache_compressed = False + pkv_cache = self.model.get_dummy_pkv_cache( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) if cache_compressed: - for lay in self.model.model.layers: - if lay is not None: - num_heads = lay.self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( - self.model.config.kv_lora_rank + self.model.config.qk_rope_head_dim - ) - example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} output_names = [v for v in output_names if "past" not in v] example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - ckv = torch.zeros((bs, num_heads, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32) - k_pe = torch.zeros( - (bs, num_heads, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32 + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) ) - example_inputs["compressed_kvs"][i].append(ckv) - example_inputs["compressed_kvs"][i].append(k_pe) dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} output_names.append(f"compressed_kv.{i}_RetainedState") output_names.append(f"k_pe.{i}_RetainedState") - else: - cache_shape_k = ( - 1, - self.model.config.num_attention_heads, - seq_len, - self.model.config.qk_nope_head_dim + self.model.config.qk_rope_head_dim, - ) - cache_shape_v = (1, self.model.config.num_attention_heads, seq_len, self.model.config.v_head_dim) example_inputs["past_key_values"] = [[] for _ in range(self.num_layers)] for i in range(self.num_layers): - 
example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_k, dtype=torch.float32)) - example_inputs["past_key_values"][i].append(torch.zeros(cache_shape_v, dtype=torch.float32)) + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) + ) if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) @@ -3422,7 +3409,7 @@ def compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, - mla_absorption_config: Optional[Dict[str, bool]] = None, + mla_absorption: Optional[Dict[str, bool]] = None, **compiler_options, ) -> str: """ @@ -3467,6 +3454,11 @@ def compile( the decode stage. If None, compiles for both stages. Default is None. use_onnx_subfunctions: bool, optional whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + mla_absorption: Dict[str, bool], optional + Configuration dictionary for multi-head latent Attention (MLA) absorption behavior. + - "cache_compressed" (bool): If True, compresses kvs are cached to save memory. + - "absorption" (bool): If True, enables absorption of attention matrices for efficiency. + - "online" (bool): If True, applies MLA absorption on device during inference **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -3504,12 +3496,12 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" - if mla_absorption_config is not None: - cache_compressed = mla_absorption_config.get("cache_compressed", False) + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) else: cache_compressed = False - if mla_absorption_config is not None and not cache_compressed: - logger.warning("mla_absorption_config will be ignored as cache_compressed is set to False") + if mla_absorption is not None and not cache_compressed: + logger.warning("mla_absorption will be ignored as cache_compressed is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( "`kv_cache_batch_size` or `full_batch_size` is being passed" @@ -3667,7 +3659,7 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, - mla_absorption_config=mla_absorption_config, + mla_absorption=mla_absorption, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 3597cac041..5ff06e6443 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1026,6 +1026,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "DeepseekV3ForCausalLM": { "forward": QEffDeepseekV3ForCausalLM.forward, "get_submodules_for_export": QEffDeepseekV3ForCausalLM.get_submodules_for_export, + "get_dummy_pkv_cache": QEffDeepseekV3ForCausalLM.get_dummy_pkv_cache, }, "DeepseekV3Model": {"forward": QEffDeepseekV3Model.forward, "__qeff_init__": QEffDeepseekV3Model.__qeff_init__}, "DeepseekV3DecoderLayer": { diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index cc0b87b604..339e4f4dac 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -136,8 +136,12 @@ def get_models_dir(): LLAMA4_ATTENTION_CHUNK_SIZE = 8192 LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 -# 
Gemma3 Constant -GEMMA3_MAX_POSITION_EMBEDDINGS = 32768 +# DeepSeek Kimi-k2 Constant +MAX_POSITION_EMBEDDINGS = 32768 +FP16_BYTES = 2 +DEFAULT_NUM_HEADS = 64 +KV_LORA_RANK = 512 +ROPE_DIM = 64 # Wav2Vec2 Constant WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) diff --git a/examples/kimi_k2/README.md b/examples/kimi_k2/README.md new file mode 100644 index 0000000000..230127ebbe --- /dev/null +++ b/examples/kimi_k2/README.md @@ -0,0 +1,28 @@ +# We should be using disaggregated serving for the Kimi-K2 model for best performance + - Kimi-K2 model has 384/8 ratio of total_experts/experts_per_tok + - Currently we always use a read-all-experts-only-once strategy in the prefill-only model + - And we treat weights as activations, meaning only the chosen experts are read, for the decode-only model + +# Multi-head Latent Attention(MLA) +Kimi-K2 uses Multi-head Latent Attention(MLA) which is implemented with dual cache (for compressed_kv and k_pe) + +# Absorption +MLA has 3 configurations based on the order of evaluation of the different matrices; to enable one, the mla_absorption config needs to be passed like this : +- No absorption : mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} +- Offline absorption : mla_absorption = {"cache_compressed": True, "absorption": True, "online": False} +- Online absorption : mla_absorption = {"cache_compressed": True, "absorption": True, "online": True} + +mla_absorption has 3 keys: +- cache_compressed: True/False -> gets enabled if compressed KVs are cached to save memory. +- absorption: True/False -> gets enabled only when compressed cache is used, if True, enables absorption of attention matrices for efficiency. +- online: True/False -> gets enabled only when absorption is True, enables on device absorption. 
+ +# Blocking +We have also implemented KV head replication, HEAD Blocking and KV Blocking which can be enabled like this : +- For No Blocking : qaic_config = {"mla_absorption" : mla_absorption} +- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_kv_heads_repeat": TS} +- For KV blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + +- Currently Decode-Only model is giving best perf with Head Blocking and compressed cache. +- Continuous batching is not enabled yet. \ No newline at end of file diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 891cc998cd..82fc294553 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -14,26 +14,25 @@ prompt = "Once upon a time," num_hidden_layers = 2 TS = 4 -mla_absorption_config = {"cache_compressed": True, "absorption": False, "online": False} -# qaic_config = None #for orig_forward -# qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward +mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking qaic_config = { + "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS, -} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat -# qaic_config = {"enable_blocking": 
True, "blocking_mode": "kv"} # for KV blocking +} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat -# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" -model_path = ( - "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" -) +model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True + model_name, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) -tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) -qeff_model = QEFFAutoModelForCausalLM(model) +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) prefill_seq_len = 1 ctx_len = 16 * 1024 @@ -41,7 +40,7 @@ qpc_path = qeff_model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - mla_absorption_config=mla_absorption_config, + mla_absorption=mla_absorption, mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=TS, diff --git a/examples/kimi_k2/run_kimik2.py b/examples/text_generation/run_kimik2.py similarity index 76% rename from examples/kimi_k2/run_kimik2.py rename to examples/text_generation/run_kimik2.py index d1c5b1abe5..b6f8d821f3 100644 --- a/examples/kimi_k2/run_kimik2.py +++ b/examples/text_generation/run_kimik2.py @@ -15,20 +15,23 @@ prompt = "Once upon a time," num_hidden_layers = 2 TS = 4 -mla_absorption_config = {"cache_compressed": False, "absorption": False, "online": False} -# qaic_config = None #for orig_forward -# qaic_config = {"num_kv_heads_repeat": TS} #with head replication for orig_forward -# qaic_config = {"enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h 
blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat -qaic_config = {"enable_blocking": True, "blocking_mode": "kv"} # for KV blocking - -# model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/" -model_path = ( - "/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd" -) +mla_absorption = {"cache_compressed": False, "absorption": False, "online": False} +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + "mla_absorption": mla_absorption, + "enable_blocking": True, + "blocking_mode": "h", + "num_kv_heads_repeat": TS, +} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + +model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True + model_name, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True ) -tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) PREFILL_SEQ_LEN = 32 CTX_LEN = 8192 @@ -44,9 +47,8 @@ # out = model(**inputs) # predictions = torch.argmax(out.logits, dim=-1) -qeff_model = QEFFAutoModelForCausalLM(model) +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) qeff_model.transform(ctx_len=CTX_LEN, seq_len=PREFILL_SEQ_LEN, bs=1, num_devices=TS, qaic_config=qaic_config) -qeff_model.mla(mla_absorption_config=mla_absorption_config) inputs = tokenizer(prompt, 
return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -81,7 +83,7 @@ x = (ckv, k_pe) compressed_kvs.append(x) -cache_compressed = mla_absorption_config.get("cache_compressed", False) +cache_compressed = mla_absorption.get("cache_compressed", False) if cache_compressed: inputs["compressed_kvs"] = compressed_kvs else: From 01867eae8dbaa1c94ac3c95dc295ce98e7094c2b Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 28 Apr 2026 00:50:54 +0530 Subject: [PATCH 39/51] fix mla_absorption_config Signed-off-by: Mamta Singh --- QEfficient/base/modeling_qeff.py | 10 +++++----- QEfficient/transformers/modeling_utils.py | 2 +- QEfficient/transformers/models/modeling_auto.py | 11 ++--------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 30ee4d5696..e9213761d9 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -287,7 +287,7 @@ def _export( """ # TODO: Hack for retain_full_kv, handle this outside export_kwargs.pop("retain_full_kv", None) - export_kwargs.pop("mla_absorption_config", None) + export_kwargs.pop("mla_absorption", None) onnx_path = export_dir / f"{self.model_name}.onnx" # Return early if ONNX already exists @@ -392,7 +392,7 @@ def get_onnx_path( offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, retain_full_kv: Optional[bool] = False, - mla_absorption_config: Optional[Dict[str, bool]] = None, + mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, **compiler_options, ): @@ -400,7 +400,7 @@ def get_onnx_path( "offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, - "mla_absorption_config": mla_absorption_config, + "mla_absorption": mla_absorption, } if prefill_only: @@ -498,7 +498,7 @@ def _compile( 
offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, - mla_absorption_config: Optional[Dict[str, bool]] = None, + mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, specialization_module_name: Optional[str] = None, **compiler_options, @@ -539,7 +539,7 @@ def _compile( offload_pt_weights, use_onnx_subfunctions, retain_full_kv, - mla_absorption_config, + mla_absorption, num_devices=mdp_ts_num_devices, qaic_config=qaic_config, **compiler_options, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index a29d0e0966..f9d7fe62cd 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -196,7 +196,7 @@ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # This is for supporting different modelling classes specially written for prefill-only model -SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"} +SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "kimi_k2", "kimi_k25"} _PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4820715d7e..cc1921d2aa 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2792,12 +2792,14 @@ def __update_prefill_transform( retain_full_kv: Optional[bool] = False, ): if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) if enable_chunking: self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) else: self.model, tf = PrefillOnlyTransform.apply(self.model) else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) if retain_full_kv: self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) else: @@ -2874,7 +2876,6 @@ def __init__( 
self.ccl_enabled = qaic_config.get("ccl_enabled", False) if mla_absorption := qaic_config.get("mla_absorption", None): self.hash_params["mla_absorption"] = mla_absorption - # setattr(self.model.model, "mla_absorption", mla_absorption) setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached @@ -3088,14 +3089,6 @@ def export( ) enable_chunking = kwargs.get("enable_chunking", False) - if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - if prefill_only: - self.prefill(enable=True) - self.hash_params["prefill_only"] = True - else: - self.prefill(enable=False) - self.hash_params.pop("prefill_only", None) - # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: From 4f7d28c4a9fbf64bc39efb8c1e34f57995048f64 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 28 Apr 2026 01:09:27 +0530 Subject: [PATCH 40/51] added head expansion for kv blocking Signed-off-by: Onkar Chougule --- QEfficient/blocking/blocked_attention_forwards.py | 14 +++++++++++++- QEfficient/transformers/models/modeling_auto.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index f5de76f37a..9bce4336bb 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -893,9 +893,12 @@ def blocked_kv_mla_attention_forward( absorption = mla_absorption.get("absorption", False) else: absorption = False - if absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) + k_heads, q_heads = krope_nope.shape[1], query.shape[1] + num_heads_to_repeat = q_heads - k_heads + repeated_k = krope_nope[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, 
module.qk_rope_head_dim + module.kv_lora_rank) + krope_nope = torch.cat((krope_nope, repeated_k), dim=1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) @@ -909,7 +912,15 @@ def blocked_kv_mla_attention_forward( skip_future, ) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: + k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] + num_heads_to_repeat = q_heads - k_heads + repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank) + compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) + + repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) + krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) @@ -923,6 +934,7 @@ def blocked_kv_mla_attention_forward( skip_future, ) + attn_output = torch.matmul(output, per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous() attn_weights = None diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4820715d7e..0f3350fa18 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3199,7 +3199,7 @@ def export( output_names.append(f"past_{kv}.{i}_RetainedState") if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - mla_absorption = 
kwargs.get("mla_absorption", None) + mla_absorption = self.model.qaic_config.get("mla_absorption", None) if mla_absorption is not None: cache_compressed = mla_absorption.get("cache_compressed", False) else: From b05aa066eb0911b34d508308a144fb409484f81c Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 28 Apr 2026 02:07:32 +0530 Subject: [PATCH 41/51] fixed kv with kv replication Signed-off-by: Onkar Chougule --- QEfficient/blocking/blocked_attention_forwards.py | 14 ++++++++++++-- QEfficient/transformers/models/modeling_auto.py | 4 ++-- examples/kimi_k2/export_kimik2.py | 1 - 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 9bce4336bb..6b03286503 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -841,7 +841,7 @@ def blocked_kv_mla_attention_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize result tensor batch_size, num_heads, seq_len, _ = query.shape - output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device) + output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype) if hasattr(module, "config"): mask_dtype = module.config.torch_dtype @@ -854,9 +854,10 @@ def blocked_kv_mla_attention_forward( (batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE), device=query.device, + dtype=query.dtype, ) skip_kv = True - current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device) + current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device, dtype=query.dtype) ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2] kv_block_size = -(-ctx_len // num_kv_blocks) @@ -893,6 +894,15 @@ def blocked_kv_mla_attention_forward( absorption = mla_absorption.get("absorption", False) else: absorption = 
False + + k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] + num_heads_to_repeat = q_heads - k_heads + repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank) + compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) + + repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) + if absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) k_heads, q_heads = krope_nope.shape[1], query.shape[1] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c4096593e6..ae7100f713 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3192,8 +3192,8 @@ def export( output_names.append(f"past_{kv}.{i}_RetainedState") if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): - mla_absorption = self.model.qaic_config.get("mla_absorption", None) - if mla_absorption is not None: + if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: + mla_absorption = self.model.qaic_config["mla_absorption"] cache_compressed = mla_absorption.get("cache_compressed", False) else: cache_compressed = False diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 82fc294553..416d8133ad 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -40,7 +40,6 @@ qpc_path = qeff_model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - mla_absorption=mla_absorption, mxfp6_matmul=True, mxint8_kv_cache=False, num_devices=TS, From c24d8510d162a4efd6a6a3d0b528d69177f387d3 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 28 Apr 2026 02:29:20 +0530 Subject: [PATCH 42/51] minor fix 
Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ae7100f713..e6561178a4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3489,7 +3489,8 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ - if mla_absorption is not None: + if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: + mla_absorption = self.model.qaic_config["mla_absorption"] cache_compressed = mla_absorption.get("cache_compressed", False) else: cache_compressed = False From bdbf50c094c9eb9aff8a26b694dc1fdaadae08ef Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 15:51:48 +0530 Subject: [PATCH 43/51] Added all changes of layer wise for kimi model Signed-off-by: Abhishek Kumar Singh --- QEfficient/base/modeling_qeff.py | 438 +++++++------- QEfficient/base/onnx_transforms.py | 71 ++- .../blocking/blocked_attention_forwards.py | 29 +- .../models/deepseek_v3/modeling_deepseek.py | 81 ++- QEfficient/utils/compile_layerwise.py | 238 ++++++++ QEfficient/utils/inference.py | 206 +++++++ QEfficient/utils/layerwise_pipeline.py | 546 ++++++++++++++++++ run.py | 227 ++++++++ .../test_audio_embedding_models.py | 3 - .../test_speech_seq2seq_models.py | 2 - .../causal_lm_models/check_causal_models.py | 1 - .../test_causal_lm_blocking_hqkv.py | 6 - .../causal_lm_models/test_causal_lm_models.py | 4 - .../causal_lm_models/test_causal_lm_pl1.py | 6 - .../test_causal_tlm_models.py | 6 - .../causal_lm_models/test_fp16_causal_lm.py | 3 - .../image_text_to_text/test_custom_dtype.py | 2 - .../test_causal_lm_blocking_subfunction.py | 3 - .../subfunction/test_subfunction_vlm.py | 4 - 19 files changed, 1599 insertions(+), 277 deletions(-) 
create mode 100644 QEfficient/utils/compile_layerwise.py create mode 100644 QEfficient/utils/inference.py create mode 100644 QEfficient/utils/layerwise_pipeline.py create mode 100644 run.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e9213761d9..4dad077d8c 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,8 +8,6 @@ import gc import inspect import logging -import shutil -import subprocess import warnings from abc import ABC, abstractmethod from pathlib import Path @@ -20,13 +18,13 @@ from QEfficient.base.onnx_transforms import ( BaseOnnxTransform, - FP16ClipTransform, + CustomOpTransform, OnnxTransformPipeline, + RenameFunctionOutputsTransform, SplitTensorsTransform, ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform -from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.models.pytorch_transforms import ( BlockingAttentionTransform, @@ -34,15 +32,10 @@ ) from QEfficient.utils import ( constants, - create_json, create_model_params, dump_qconfig, - generate_mdp_partition_config, get_attr_or_key, - hash_dict_params, - load_json, require_value, - to_named_specializations, ) from QEfficient.utils.export_utils import export_wrapper @@ -59,6 +52,10 @@ class QEFFBaseModel(ABC): :_onnx_transforms: ONNX transformations to be applied after ONNX export. """ + _start = 0 + _end = 1 + _total_layers = None + _pytorch_transforms: List[PytorchTransform] _onnx_transforms = [BaseOnnxTransform] @@ -285,6 +282,13 @@ def _export( instance using from_pretrained() for re-export. 
""" + + idx = int(QEFFBaseModel._start) + # agent change start: generalized layerwise window + end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1)) + if end_idx <= idx: + raise ValueError(f"Invalid export window: start={idx}, end={end_idx}") + # TODO: Hack for retain_full_kv, handle this outside export_kwargs.pop("retain_full_kv", None) export_kwargs.pop("mla_absorption", None) @@ -300,12 +304,44 @@ def _export( export_dir.mkdir(parents=True, exist_ok=True) + # Setup temporary paths + tmp_onnx_dir = export_dir / "onnx_layerwise_tmp" + tmp_onnx_dir.mkdir(parents=True, exist_ok=True) + + output_name = [] + output_name.append("logits") + # agent change start: emit retained states for all layers in current export window + for layer_idx in range(idx, end_idx): + output_name.append(f"compressed_kv.{layer_idx}_InternalRetainedState") + output_name.append(f"k_pe.{layer_idx}_InternalRetainedState") + + if idx >= 1: + z = example_inputs.pop("input_ids") + # z = example_inputs["input_ids"] + ################### model_dependent ############################ + inputs_embeds = torch.rand(z.shape[0], z.shape[1], 7168, device=z.device, dtype=torch.float16) + # example_inputs[f"layer_{QEFFBaseModel._start}/inputs_embeds"] = inputs_embeds + # dynamic_axes[f"layer_{QEFFBaseModel._start}/inputs_embeds"] = dynamic_axes.pop("input_ids") + example_inputs["inputs_embeds"] = inputs_embeds + dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids") + # Create input_names from example_inputs + # example_inputs[f"layer_{QEFFBaseModel._start}/position_ids"] = example_inputs.pop("position_ids") + # dynamic_axes[f"layer_{QEFFBaseModel._start}/position_ids"] = dynamic_axes.pop("position_ids") + + window_size = end_idx - idx + if "compressed_kvs" in example_inputs: + example_inputs["compressed_kvs"] = [ + val for i, val in enumerate(example_inputs["compressed_kvs"]) if i < window_size + ] + # Create input_names from example_inputs input_names = [] for param in 
inspect.signature(self.model.forward).parameters: if param in example_inputs: if param == "past_key_values": for i in range(len(example_inputs["past_key_values"])): + # example_inputs["past_key_values"] = [ + # val for i, val in enumerate(example_inputs["past_key_values"]) if i < window_size] if len(example_inputs["past_key_values"][0]) == 2: input_names.extend([f"past_key.{i}", f"past_value.{i}"]) elif len(example_inputs["past_key_values"][0]) == 4: @@ -322,67 +358,68 @@ def _export( f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(example_inputs['past_key_values'][0])}" ) elif param == "compressed_kvs": - for i in range(len(example_inputs["compressed_kvs"])): - input_names.extend( - [ - f"compressed_kv.{i}", - ] - ) - input_names.extend( - [ - f"k_pe.{i}", - ] - ) + if len(example_inputs["compressed_kvs"][0]) == 2: + for layer_offset in range(len(example_inputs["compressed_kvs"])): + layer_idx = idx + layer_offset + input_names.extend([f"compressed_kv.{layer_idx}", f"k_pe.{layer_idx}"]) + else: + for i in range(len(example_inputs["compressed_kvs"])): + input_names.extend( + [ + f"compressed_kv.{i}", + ] + ) + input_names.extend( + [ + f"k_pe.{i}", + ] + ) else: input_names.append(param) + dynamic_axes = {k: v for k, v in dynamic_axes.items() if k in input_names} + + import os + import time - try: + layerwise_dir = export_dir / "onnx_layerwise_tmp" + start_time = time.time() + + # example_inputs["layer_indices_to_run"] = [i] + current_layer_dir = layerwise_dir / f"layer_{idx}_{end_idx}" + current_layer_dir.mkdir(parents=True, exist_ok=True) + + layer_onnx_path = str(current_layer_dir / f"{self.model_name}_layer_{idx}_{end_idx}.onnx") + layer_onnx_path_tmp = str(current_layer_dir / f"{self.model_name}_layer_tmp_{idx}_{end_idx}.onnx") + if not os.path.isfile(layer_onnx_path): torch.onnx.export( self.model, (example_inputs,), - str(onnx_path), + layer_onnx_path_tmp, input_names=input_names, - 
output_names=output_names, + output_names=output_name, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, **export_kwargs, ) - logger.info("PyTorch export successful") - _ = self._offload_model_weights(offload_pt_weights) - model = onnx.load(onnx_path, load_external_data=False) - - needs_external_tensor_data = any( - transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform) - ) - transform_kwargs = { - "onnx_base_dir": str(export_dir) if needs_external_tensor_data else None, - "model_name": self.model_name, - } - if onnx_transform_kwargs is not None: - transform_kwargs.update(onnx_transform_kwargs) - - onnx_transforms = OnnxTransformPipeline(transforms=self._onnx_transforms) - model, transformed = onnx_transforms.apply(model, **transform_kwargs) - - # Add metadata to the model - model.metadata_props.append( - onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names())) - ) - logger.info("ONNX transforms applied") - - onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp") - onnx.save(model, onnx_path_tmp) - onnx_path_tmp.replace(onnx_path) - del model - gc.collect() - logger.info("Transformed ONNX saved") - - except Exception as e: - logger.error(f"ONNX export or transforms failed: {e}") - raise e + total_end = time.time() + print(f"\nTotal export time: {total_end - start_time:.2f} seconds") + + model = onnx.load(layer_onnx_path_tmp, load_external_data=False) + # print(model.functions) + transform_kwargs = { + "onnx_base_dir": str(current_layer_dir), + "model_name": self.model_name, + "layer_idx": idx, + } + _onnx_transforms = [SplitTensorsTransform, CustomOpTransform, RenameFunctionOutputsTransform] + onnx_transforms = OnnxTransformPipeline(transforms=_onnx_transforms) + model, transformed = onnx_transforms.apply(model, **transform_kwargs) + onnx.save(model, layer_onnx_path_tmp) + self.onnx_path = layer_onnx_path_tmp + import pdb - self.onnx_path = onnx_path - return 
onnx_path + pdb.set_trace() + return layer_onnx_path_tmp def get_onnx_path( self, @@ -545,145 +582,146 @@ def _compile( **compiler_options, ) ) - compile_dir = Path(compile_dir or onnx_path.parent) - qpc_path = compile_dir / "qpc" - if not onnx_path.is_file(): - raise FileNotFoundError(f"ONNX file not found at: {onnx_path}") - - if enable_qnn: - if compiler_options: - logger.warning( - f"Extra arguments to QNN compilation are supported only via qnn_config file. Ignoring {compiler_options}" - ) - - self.qpc_path = qnn_compile( - onnx_path=onnx_path, - qpc_base_path=compile_dir, - specializations=specializations, - custom_io=custom_io, - device_group=list(range(mdp_ts_num_devices)), - num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), - mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL), - mxint8=mxint8_kv_cache, - qnn_config=qnn_config, - ) - - return self.qpc_path - - command = ( - constants.COMPILER - + [ - f"-aic-hw-version={compiler_options.pop('aic_hw_version', compiler_options.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}" - ] - + [f"-m={onnx_path}"] - ) - - # MDP partition config: prioritize dump over load - mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) - mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) - mdp_ts_json = None - - if mdp_dump_json_path: - if mdp_ts_json_path: - logger.warning( - "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" 
- ) - command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") - elif mdp_ts_json_path: - command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") - mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_ts_num_devices > 1: - # Generate mdp config only if neither dump nor load is provided and num_devices > 1 - mdp_ts_json = generate_mdp_partition_config( - mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) - ) - mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" - create_json(str(mdp_ts_json_path), mdp_ts_json) - command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") - - for key, value in compiler_options.items(): - option = "-" + key.replace("_", "-") - if isinstance(value, bool): - if value: - command.append(option) - continue - command.append(f"{option}={value}") - - if use_onnx_subfunctions: - logger.info("Using ONNX subfunctions for compilation.") - command.append("-sub-functions") - - compile_hash_params = { - "command": command, - "specializations": specializations, - "custom_io": custom_io, - "mdp_ts_num_devices": mdp_ts_num_devices, - "mdp_ts_json": mdp_ts_json, - "num_speculative_tokens": num_speculative_tokens, - "prefill_only": prefill_only, - } - compile_hash = hash_dict_params(compile_hash_params) - - compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) - qpc_path = compile_dir / "qpc" - qpc_path.mkdir(parents=True, exist_ok=True) - - if qpc_path.is_dir(): - if (qpc_path / "programqpc.bin").is_file(): - self.qpc_path = qpc_path - return qpc_path - # Probably compilation failure last time, delete directory to start over - shutil.rmtree(qpc_path) - - # Write the generated MDP partition config file (not if user provided it) - - # Write specializations.json file - if specializations is not None: - specializations_json = compile_dir / "specializations.json" - specializations_data = { - "specializations": to_named_specializations(specializations, 
module_name=specialization_module_name) - } - create_json(str(specializations_json), specializations_data) - command.append(f"-network-specialization-config={specializations_json}") - - # Write custom_io.yaml file - model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16) - pkv_in_bfloat16 = (custom_io is not None) and any( - "past_" in key and "bfloat16" in value for key, value in custom_io.items() - ) - if custom_io is not None: - custom_io_yaml = compile_dir / "custom_io.yaml" - with open(custom_io_yaml, "w") as fp: - for io_name, dtype in custom_io.items(): - fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n") - if model_in_bfloat16 and pkv_in_bfloat16: - logger.warning( - "Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile." - ) - else: - command.append(f"-custom-IO-list-file={custom_io_yaml}") - - command.append(f"-aic-binary-dir={qpc_path}") - logger.info(f"Running compiler: {' '.join(command)}") - - try: - subprocess.run(command, capture_output=True, check=True) - except subprocess.CalledProcessError as e: - raise RuntimeError( - "\n".join( - [ - "Compilation failed!", - f"Compiler command: {e.cmd}", - f"Compiler exitcode: {e.returncode}", - "Compiler stderr:", - e.stderr.decode(), - ] - ) - ) - # Dump JSON file with hashed parameters - hashed_compile_params_path = compile_dir / "hashed_compile_params.json" - create_json(hashed_compile_params_path, compile_hash_params) - logger.info("Hashed parameters exported successfully.") - - self.qpc_path = qpc_path - return qpc_path + return onnx_path + # compile_dir = Path(compile_dir or onnx_path.parent) + # qpc_path = compile_dir / "qpc" + # if not onnx_path.is_file(): + # raise FileNotFoundError(f"ONNX file not found at: {onnx_path}") + + # if enable_qnn: + # if compiler_options: + # logger.warning( + # f"Extra arguments to QNN compilation are supported only via qnn_config file. 
Ignoring {compiler_options}" + # ) + + # self.qpc_path = qnn_compile( + # onnx_path=onnx_path, + # qpc_base_path=compile_dir, + # specializations=specializations, + # custom_io=custom_io, + # device_group=list(range(mdp_ts_num_devices)), + # num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + # mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL), + # mxint8=mxint8_kv_cache, + # qnn_config=qnn_config, + # ) + + # return self.qpc_path + + # command = ( + # constants.COMPILER + # + [ + # f"-aic-hw-version={compiler_options.pop('aic_hw_version', compiler_options.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}" + # ] + # + [f"-m={onnx_path}"] + # ) + + # # MDP partition config: prioritize dump over load + # mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) + # mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) + # mdp_ts_json = None + + # if mdp_dump_json_path: + # if mdp_ts_json_path: + # logger.warning( + # "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" 
+ # ) + # command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") + # elif mdp_ts_json_path: + # command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") + # mdp_ts_json = load_json(str(mdp_ts_json_path)) + # elif mdp_ts_num_devices > 1: + # # Generate mdp config only if neither dump nor load is provided and num_devices > 1 + # mdp_ts_json = generate_mdp_partition_config( + # mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) + # ) + # mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" + # create_json(str(mdp_ts_json_path), mdp_ts_json) + # command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") + + # for key, value in compiler_options.items(): + # option = "-" + key.replace("_", "-") + # if isinstance(value, bool): + # if value: + # command.append(option) + # continue + # command.append(f"{option}={value}") + + # if use_onnx_subfunctions: + # logger.info("Using ONNX subfunctions for compilation.") + # command.append("-sub-functions") + + # compile_hash_params = { + # "command": command, + # "specializations": specializations, + # "custom_io": custom_io, + # "mdp_ts_num_devices": mdp_ts_num_devices, + # "mdp_ts_json": mdp_ts_json, + # "num_speculative_tokens": num_speculative_tokens, + # "prefill_only": prefill_only, + # } + # compile_hash = hash_dict_params(compile_hash_params) + + # compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) + # qpc_path = compile_dir / "qpc" + # qpc_path.mkdir(parents=True, exist_ok=True) + + # if qpc_path.is_dir(): + # if (qpc_path / "programqpc.bin").is_file(): + # self.qpc_path = qpc_path + # return qpc_path + # # Probably compilation failure last time, delete directory to start over + # shutil.rmtree(qpc_path) + + # # Write the generated MDP partition config file (not if user provided it) + + # # Write specializations.json file + # if specializations is not None: + # specializations_json = compile_dir / "specializations.json" + # 
specializations_data = { + # "specializations": to_named_specializations(specializations, module_name=specialization_module_name) + # } + # create_json(str(specializations_json), specializations_data) + # command.append(f"-network-specialization-config={specializations_json}") + + # # Write custom_io.yaml file + # model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16) + # pkv_in_bfloat16 = (custom_io is not None) and any( + # "past_" in key and "bfloat16" in value for key, value in custom_io.items() + # ) + # if custom_io is not None: + # custom_io_yaml = compile_dir / "custom_io.yaml" + # with open(custom_io_yaml, "w") as fp: + # for io_name, dtype in custom_io.items(): + # fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n") + # if model_in_bfloat16 and pkv_in_bfloat16: + # logger.warning( + # "Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile." + # ) + # else: + # command.append(f"-custom-IO-list-file={custom_io_yaml}") + + # command.append(f"-aic-binary-dir={qpc_path}") + # logger.info(f"Running compiler: {' '.join(command)}") + + # try: + # subprocess.run(command, capture_output=True, check=True) + # except subprocess.CalledProcessError as e: + # raise RuntimeError( + # "\n".join( + # [ + # "Compilation failed!", + # f"Compiler command: {e.cmd}", + # f"Compiler exitcode: {e.returncode}", + # "Compiler stderr:", + # e.stderr.decode(), + # ] + # ) + # ) + # # Dump JSON file with hashed parameters + # hashed_compile_params_path = compile_dir / "hashed_compile_params.json" + # create_json(hashed_compile_params_path, compile_hash_params) + # logger.info("Hashed parameters exported successfully.") + + # self.qpc_path = qpc_path + # return qpc_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index c27e3cc704..91c6ef3e27 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -129,17 +129,80 @@ def 
apply(cls, model: ModelProto) -> bool: return op_applied +class RemovePrefix(BaseOnnxTransform): + @classmethod + def apply(cls, model: ModelProto) -> bool: + graph = model.graph + renamed = False + + def strip_prefix(name: str) -> str: + parts = name.rsplit("/", 1) + return parts[1] if len(parts) == 2 else parts[0] + + input_names = [] + for i, inputs in enumerate(graph.input): + original = inputs.name + new = strip_prefix(original) + if new != original: + renamed = True + inputs.name = new + graph.input[i].name = new + input_names.append(new) + + input_name_set = set(input_names) + output_rename_map = {} + + # Rename model graph outputs and keep mapping so producer/consumer edges can be fixed. + for out in graph.output: + original = out.name + new = strip_prefix(original) + if new != original: + out.name = new + output_rename_map[original] = new + renamed = True + + for node in graph.node: + for i, out in enumerate(node.output): + if out in output_rename_map and output_rename_map[out] != out: + node.output[i] = output_rename_map[out] + renamed = True + + new_inputs = [] + for s in node.input: + # Keep node inputs in sync for renamed model outputs. 
+ if s in output_rename_map: + new_inputs.append(output_rename_map[s]) + continue + + if s in input_name_set: + new_inputs.append(s) + continue + + replaced = s + if "/" in s: + tail = s.rsplit("/", 1)[1] + if tail in input_name_set: + replaced = tail + new_inputs.append(replaced) + + for idx in range(len(node.input)): + if node.input[idx] != new_inputs[idx]: + node.input[idx] = new_inputs[idx] + renamed = True + + return renamed + + class RenameFunctionOutputsTransform(BaseOnnxTransform): """Rename outputs of decoder-related functions for better clarity.""" @classmethod - def apply(cls, model: ModelProto) -> bool: + def apply(cls, model: ModelProto, layer_idx=0) -> bool: graph = model.graph op_type_to_func = {f.name: f for f in model.functions} decoder_patterns = ["DecoderLayer", "Block", "Layer"] renamed = False model_out_map = {v.name: i for i, v in enumerate(graph.output)} - layer_idx = 0 for node in graph.node: if any(p in node.name or p in node.op_type for p in decoder_patterns): @@ -278,7 +341,9 @@ def _set_external_data(tensor, file_name): applied[CustomOpTransform] = CustomOpTransform.apply(model) if RenameFunctionOutputsTransform in requested: - applied[RenameFunctionOutputsTransform] = RenameFunctionOutputsTransform.apply(model) + applied[RenameFunctionOutputsTransform] = RenameFunctionOutputsTransform.apply( + model, layer_idx=kwargs.get("layer_idx", 0) + ) if AdapterWeightsToInputsTransform in requested: applied[AdapterWeightsToInputsTransform] = AdapterWeightsToInputsTransform.apply(model, **kwargs) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 6b03286503..37f65034a3 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -841,7 +841,9 @@ def blocked_kv_mla_attention_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize result tensor batch_size, num_heads, seq_len, _ = query.shape - output = 
torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype) + output = torch.zeros( + batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype + ) if hasattr(module, "config"): mask_dtype = module.config.torch_dtype @@ -897,17 +899,23 @@ def blocked_kv_mla_attention_forward( k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank) + repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.kv_lora_rank + ) compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) - repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + repeated_k_pe_block = k_pe_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + ) k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) if absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) k_heads, q_heads = krope_nope.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_k = krope_nope[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank) + repeated_k = krope_nope[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank + ) krope_nope = torch.cat((krope_nope, repeated_k), dim=1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] @@ -924,13 +932,17 @@ def blocked_kv_mla_attention_forward( else: k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, 
num_heads_to_repeat, -1, module.kv_lora_rank) + repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.kv_lora_rank + ) compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) - - repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + + repeated_k_pe_block = k_pe_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + ) k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) - + krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) @@ -944,7 +956,6 @@ def blocked_kv_mla_attention_forward( skip_future, ) - attn_output = torch.matmul(output, per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous() attn_weights = None diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 0e1f270e20..57dd4793f2 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -55,7 +55,7 @@ def yarn_linear_ramp_mask(min, max, dim): if min == max: max += 0.001 # Prevent singularity - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + linear_func = (torch.arange(dim, dtype=torch.float16) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func @@ -145,9 +145,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len dim = self.dim - freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_extra = 1.0 / (self.base ** (torch.arange(0, 
dim, 2, dtype=torch.float16, device=device) / dim)) freq_inter = 1.0 / ( - self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float16, device=device) / dim) ) low, high = yarn_find_correction_range( @@ -157,11 +157,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.base, self.original_max_position_embeddings, ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float16) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len, device=device, dtype=torch.float32) + t = torch.arange(seq_len, device=device, dtype=torch.float16) freqs = torch.outer(t, inv_freq) @@ -282,14 +282,16 @@ def fused_forward_h_blocking( kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) + if compressed_kvs is not None: - kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) + kva = compressed_kvs.update_ckv(kva, window_cache_layer_idx, cache_kwargs) cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + k_pe = compressed_kvs.update_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) @@ -347,16 +349,17 @@ def fused_forward_kv_blocking( kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - 
getattr(QEffDeepseekV3Model, "_start", 0) ## Write Only if compressed_kvs is not None: - compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) + compressed_kvs.write_only_ckv(kva, window_cache_layer_idx, cache_kwargs) cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + compressed_kvs.write_only_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) if mla_absorption is not None: absorption = mla_absorption.get("absorption", False) @@ -433,8 +436,10 @@ def fused_forward_orig( kva = self.kv_a_layernorm(kva) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) + if compressed_kvs is not None: - kva = compressed_kvs.update_ckv(kva, self.layer_idx, cache_kwargs) + kva = compressed_kvs.update_ckv(kva, window_cache_layer_idx, cache_kwargs) # ---- MLA absorption flags ---- if mla_absorption is not None: @@ -452,7 +457,7 @@ def fused_forward_orig( q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if compressed_kvs is not None: - k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs) + k_pe = compressed_kvs.update_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) kva_expanded = ( kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) @@ -495,7 +500,7 @@ def fused_forward_orig( torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype), attn_weights, ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float16).to(q_pe.dtype) ## Do v_proj here attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * 
self.v_head_dim) @@ -609,19 +614,22 @@ def forward_full_kv( query_states = torch.cat((q_nope, q_pe), -1) k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) key_states = torch.cat((k_nope, k_pe_new), -1) + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, window_cache_layer_idx, cache_kwargs + ) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float16), attn_weights ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) @@ -957,6 +965,10 @@ def forward( class QEffDeepseekV3Model(nn.Module): """Adapted DeepseekV3Model with batch_index and QEff rotary embedding.""" + _start = 0 + _end = 0 + _total_layers = None + def __qeff_init__(self): scaling_factor = self.config.rope_scaling["factor"] kwargs = { @@ -993,6 +1005,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, mla_absorption: Optional[Dict[str, bool]] = None, + layer_indices_to_run: Optional[List[int]] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions @@ -1018,9 +1031,9 @@ def forward( if cache_compressed: compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs) - target_len = compressed_kvs.layers[0].ckv.shape[-2] - else: - target_len = past_key_values[0][0].shape[2] + # target_len = compressed_kvs.layers[0].ckv.shape[-2] + # else: + # target_len = past_key_values[0][0].shape[2] if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1031,7 +1044,12 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_len) + # causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_len) + start = QEffDeepseekV3Model._start + end = QEffDeepseekV3Model._end + + ctx_len = compressed_kvs.layers[0].ckv.shape[-2] + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=ctx_len) hidden_states = inputs_embeds position_embeddings = None @@ -1039,7 +1057,11 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for layer_idx, decoder_layer in enumerate(self.layers): + if layer_idx < start or layer_idx >= end: + continue + if layer_indices_to_run is not None and layer_idx not in layer_indices_to_run: + continue if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1064,7 +1086,10 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) + total_layers = getattr(QEffDeepseekV3Model, "_total_layers", len(self.layers)) + if QEffDeepseekV3Model._end == total_layers: + hidden_states = self.norm(hidden_states) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1091,7 +1116,7 @@ def get_submodules_for_export(self) -> Type[nn.Module]: This method should return the *class object* (not an instance). 
Downstream code can use this to find/build subfunctions for repeated blocks. """ - return {self.model.layers[0].__class__} + return {self.model.layers[QEffDeepseekV3Model._start].__class__} def forward( self, @@ -1109,6 +1134,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, + layer_indices_to_run: Optional[List[int]] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1132,13 +1158,18 @@ def forward( return_dict=return_dict, cache_position=cache_position, mla_absorption=mla_absorption, + layer_indices_to_run=layer_indices_to_run, **kwargs, ) hidden_states = outputs[0] - logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states).float() + total_layers = getattr(QEffDeepseekV3Model, "_total_layers", len(self.model.layers)) + if QEffDeepseekV3Model._end < total_layers: + logits = hidden_states + else: + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() loss = None if labels is not None: diff --git a/QEfficient/utils/compile_layerwise.py b/QEfficient/utils/compile_layerwise.py new file mode 100644 index 0000000000..81375472ae --- /dev/null +++ b/QEfficient/utils/compile_layerwise.py @@ -0,0 +1,238 @@ +import argparse +import os +import re +import signal +import subprocess +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +# ===================================================== +# CONFIG +# ===================================================== + +MAX_RETRIES = 1 # retries don't help for long compiles +RETRY_SLEEP = 5 
+TIMEOUT = 90 * 60 # 90 minutes + +# ===================================================== +# WORKER CONFIG (CPU-BASED) +# ===================================================== + +MAX_WORKERS = 8 + + +# ===================================================== +# DISCOVERY +# ===================================================== + + +def _discover_onnx_jobs(base_onnx_dir: str): + # agent: defer discovery to runtime and require explicit export path. + onnx_jobs = [] + base_dir_path = Path(base_onnx_dir) + layerwise_dir = base_dir_path / "onnx_layerwise_tmp" + if layerwise_dir.is_dir(): + scan_dir = layerwise_dir + elif base_dir_path.is_dir(): + scan_dir = base_dir_path + else: + raise RuntimeError(f"BASE_ONNX_DIR does not exist: {base_onnx_dir}") + + layer_dir_pat = re.compile(r"^layer_(\d+)_(\d+)$") + for layer_dir in sorted(scan_dir.iterdir()): + if not layer_dir.is_dir(): + continue + + m = layer_dir_pat.match(layer_dir.name) + if not m: + continue + + layer_start = int(m.group(1)) + layer_end = int(m.group(2)) + if layer_end <= layer_start: + continue + + layer_indices = [str(i) for i in range(layer_start, layer_end)] + layer_window = (layer_start, layer_end) + + for f in layer_dir.iterdir(): + if f.name.startswith("DeepseekV3ForCausalLM_layer_tmp_") and f.suffix == ".onnx": + # device_group fixed to single device "0" + onnx_jobs.append((f, layer_dir, layer_window, layer_indices, "0")) + + if not onnx_jobs: + raise RuntimeError(f"No valid ONNX files found under: {scan_dir}") + + return onnx_jobs + + +# ===================================================== +# CUSTOM IO YAML WRITER +# ===================================================== + + +def write_custom_io_yaml(path: Path, indices): + with open(path, "w") as fp: + # agent: write cache entries for all layers in each discovered window. 
+ for idx in indices: + fp.write(f" - IOName: k_pe.{idx}\n") + fp.write(" Precision: mxint8\n\n") + fp.write(f" - IOName: compressed_kv.{idx}\n") + fp.write(" Precision: mxint8\n\n") + + for idx in indices: + fp.write(f" - IOName: k_pe.{idx}_RetainedState\n") + fp.write(" Precision: mxint8\n\n") + fp.write(f" - IOName: compressed_kv.{idx}_RetainedState\n") + fp.write(" Precision: mxint8\n\n") + + +# ===================================================== +# COMPILE FUNCTION +# ===================================================== + + +def compile_one(job): + onnx_path, layer_dir, layer_window, layer_indices, device_group = job + + layer_tag = onnx_path.stem.replace("DeepseekV3ForCausalLM_layer_tmp_", "") + + qpc_dir = layer_dir / f"qpc_{layer_tag}" + log_file = layer_dir / f"qpc_{layer_tag}.log" + qpc_dir.mkdir(parents=True, exist_ok=True) + + custom_io_yaml = layer_dir / "custom_io_fp16.yaml" + if not custom_io_yaml.exists(): + write_custom_io_yaml(custom_io_yaml, layer_indices) + + cmd = [ + "python", + "-m", + "QEfficient.cloud.compile", + "--onnx_path", + str(onnx_path), + "--qpc-path", + str(qpc_dir), + "--batch_size", + "1", + "--prompt_len", + "1", + "--ctx_len", + "128", + "--mxfp6", + "mxint8_kv_cache", + "--num_cores", + "16", + "--device_group", + device_group, + "--mos", + "1", + "--aic_enable_depth_first", + f"-custom-IO-list-file={custom_io_yaml}", + ] + + total_start = time.time() + last_status = "FAILED" + + for attempt in range(1, MAX_RETRIES + 1): + print( + f"[START ] layer {layer_window[0]}_{layer_window[1]} " + f"device {device_group} (attempt {attempt}/{MAX_RETRIES})" + ) + + proc = None + try: + with open(log_file, "a") as lf: + lf.write(f"\n===== ATTEMPT {attempt} =====\n") + proc = subprocess.Popen( + cmd, + stdout=lf, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + proc.wait(timeout=TIMEOUT) + + if proc.returncode == 0: + last_status = "OK" + break + else: + last_status = f"FAILED(rc={proc.returncode})" + + except 
subprocess.TimeoutExpired: + last_status = "TIMEOUT" + if proc: + os.killpg(proc.pid, signal.SIGTERM) + break # do not retry timeouts + + except KeyboardInterrupt: + if proc: + os.killpg(proc.pid, signal.SIGTERM) + raise + + except Exception as e: + last_status = f"ERROR({e})" + if proc: + os.killpg(proc.pid, signal.SIGTERM) + break + + time.sleep(RETRY_SLEEP) + + total_elapsed = time.time() - total_start + + print(f"[DONE ] layer {layer_window[0]}_{layer_window[1]} {last_status} | {total_elapsed:.1f}s") + + return layer_tag, last_status, total_elapsed + + +# ===================================================== +# MAIN +# ===================================================== + + +def run_compile_layerwise(base_onnx_dir: str): + # agent: path is expected to be export root and is normalized in run.py. + onnx_jobs = _discover_onnx_jobs(base_onnx_dir) + print(f"MAX_WORKERS set to : {MAX_WORKERS}") + print(f"Found {len(onnx_jobs)} ONNX files\n") + + start_time = time.time() + results = [] + interrupted = False + + try: + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + futures = [executor.submit(compile_one, job) for job in onnx_jobs] + + for fut in as_completed(futures): + results.append(fut.result()) + + except KeyboardInterrupt: + interrupted = True + print("\n[INTERRUPT] KeyboardInterrupt received") + + finally: + total_time = time.time() - start_time + + success = sum(1 for _, s, _ in results if s == "OK") + failed = sum(1 for _, s, _ in results if s != "OK") + completed = len(results) + pending = len(onnx_jobs) - completed + + print("\n============================================") + print(f"TOTAL FILES : {len(onnx_jobs)}") + print(f"COMPLETED : {completed}") + print(f"SUCCESS : {success}") + print(f"FAILED : {failed}") + print(f"PENDING : {pending}") + print(f"TOTAL TIME : {total_time:.1f} seconds") + print(f"INTERRUPTED : {interrupted}") + print("============================================") + + +if __name__ == "__main__": + # agent: CLI now 
takes exported path instead of embedded machine-local constant. + parser = argparse.ArgumentParser(description="Compile layerwise ONNX windows into QPC artifacts.") + parser.add_argument("--base-onnx-dir", required=True, help="Export root containing onnx_layerwise_tmp/") + args = parser.parse_args() + run_compile_layerwise(args.base_onnx_dir) diff --git a/QEfficient/utils/inference.py b/QEfficient/utils/inference.py new file mode 100644 index 0000000000..d686fc29d4 --- /dev/null +++ b/QEfficient/utils/inference.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import argparse +import re +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from transformers import AutoTokenizer + +from QEfficient.generation.cloud_infer import QAICInferenceSession + +LAYER_DIR_RE = re.compile(r"layer_(\d+)_(\d+)$") + + +def discover_qpc_paths(base_path: Path) -> List[Path]: + layer_dirs = [] + for child in base_path.iterdir(): + if not child.is_dir(): + continue + match = LAYER_DIR_RE.match(child.name) + if not match: + continue + layer_dirs.append((int(match.group(1)), int(match.group(2)), child)) + + if not layer_dirs: + raise FileNotFoundError(f"No layer directories found under: {base_path}") + + layer_dirs.sort(key=lambda x: (x[0], x[1])) + qpc_paths: List[Path] = [] + for _, _, layer_dir in layer_dirs: + candidates = sorted(p for p in layer_dir.glob("**/qpcs") if p.is_dir()) + if not candidates: + raise FileNotFoundError(f"No qpcs directory found in: {layer_dir}") + qpc_paths.append(candidates[0]) + return qpc_paths + + +def pick_token_input_name(session: QAICInferenceSession) -> Optional[str]: + if "input_ids" in session.input_names: + return "input_ids" + for name in session.input_names: + if "input_ids" in name: + return name + return None + + +def pick_hidden_input_name(session: QAICInferenceSession) -> Optional[str]: + for preferred in ("inputs_embeds", "input_embeds"): + if preferred in session.input_names: + return preferred 
def pick_main_output_name(session: QAICInferenceSession) -> str:
    """Choose the output binding that carries this session's main result.

    ``"logits"`` wins whenever the session exposes it. Otherwise the last
    output that is not a KV-cache buffer is returned. Cache outputs (names
    containing ``compressed_kv`` or ``k_pe``) are excluded because the caller
    in this module passes exactly those names to ``skip_buffers``; picking one
    of them as the main output would select a buffer that is being skipped.

    Args:
        session: Object exposing an ``output_names`` sequence of strings.

    Returns:
        The name of the chosen output binding.

    Raises:
        RuntimeError: If the session has no outputs, or only cache outputs.
    """
    output_names = list(session.output_names)
    if "logits" in output_names:
        return "logits"

    # Same cache-name filter that pick_hidden_input_name applies to inputs.
    candidates = [name for name in output_names if "compressed_kv" not in name and "k_pe" not in name]
    if not candidates:
        raise RuntimeError(f"No usable output name found for session outputs: {session.output_names}")
    return candidates[-1]
If set, layer i uses device_start + i.", + ) + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained( + args.model_name, + padding_side="right", + trust_remote_code=True, + ) + prompt_ids = tokenizer(args.prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist() + all_ids = list(prompt_ids) + + qpc_paths = discover_qpc_paths(args.base_path) + print(f"[LOAD] Found {len(qpc_paths)} layer sessions") + + sessions: List[Dict[str, object]] = [] + for i, qpc in enumerate(qpc_paths): + device_ids = [args.device_start + i] if args.device_start is not None else None + session = QAICInferenceSession(str(qpc), device_ids=device_ids) + session.skip_buffers( + [n for n in session.input_names + session.output_names if "compressed_kv" in n or "k_pe" in n] + ) + + out_name = pick_main_output_name(session) + session.set_buffers({out_name: output_placeholder(session, out_name)}) + + sessions.append( + { + "session": session, + "token_input": pick_token_input_name(session), + "hidden_input": pick_hidden_input_name(session), + "pos_input": pick_pos_input_name(session), + "out_name": out_name, + } + ) + print(f"[LOAD] layer {i}: {qpc} -> out={out_name}") + + if not sessions: + raise RuntimeError("No sessions loaded") + if sessions[0]["token_input"] is None: + raise RuntimeError(f"First layer has no token input. inputs={sessions[0]['session'].input_names}") + + logits = None + + # Prefill: pass each prompt token through all layers + for pos, token_id in enumerate(prompt_ids): + hidden = None + for i, info in enumerate(sessions): + session = info["session"] + run_inputs: Dict[str, np.ndarray] = {} + if i == 0: + run_inputs[info["token_input"]] = np.array([[token_id]], dtype=np.int64) + else: + if hidden is None: + raise RuntimeError("Missing hidden state while executing intermediate layer") + if info["hidden_input"] is None: + raise RuntimeError(f"Layer {i} has no hidden-state input. 
inputs={session.input_names}") + run_inputs[info["hidden_input"]] = hidden + + if info["pos_input"] is not None: + run_inputs[info["pos_input"]] = np.array([[pos]], dtype=np.int64) + + outputs = session.run(run_inputs) + hidden = outputs[info["out_name"]] + logits = hidden + + if logits is None: + raise RuntimeError("Prompt produced no logits") + + # Decode + generated_ids: List[int] = [] + while len(all_ids) < args.max_len: + next_token_id = int(np.argmax(logits, axis=-1)[0, 0]) + generated_ids.append(next_token_id) + all_ids.append(next_token_id) + + if tokenizer.eos_token_id is not None and next_token_id == tokenizer.eos_token_id: + break + + pos = len(all_ids) - 1 + hidden = None + for i, info in enumerate(sessions): + session = info["session"] + run_inputs: Dict[str, np.ndarray] = {} + if i == 0: + run_inputs[info["token_input"]] = np.array([[next_token_id]], dtype=np.int64) + else: + if hidden is None: + raise RuntimeError("Missing hidden state while decoding intermediate layer") + if info["hidden_input"] is None: + raise RuntimeError(f"Layer {i} has no hidden-state input. 
def _discover_layer_windows(exported_path: str, start_layer: int = 0) -> List[Tuple[int, int]]:
    """Find the exported layer windows under ``<exported_path>/onnx_layerwise_tmp``.

    A window is a sub-directory named ``layer_<start>_<end>``. Entries are
    ignored when they are not directories, do not match that pattern, have
    ``end <= start``, or have ``start < start_layer``.

    Args:
        exported_path: Export root that contains ``onnx_layerwise_tmp``.
        start_layer: Smallest window start to keep.

    Returns:
        ``(start, end)`` tuples sorted ascending by start, then by end, so the
        order does not depend on directory-scan order.

    Raises:
        FileNotFoundError: If the ``onnx_layerwise_tmp`` directory is missing.
        RuntimeError: If no window survives the filters.
    """
    base_path = f"{exported_path}/onnx_layerwise_tmp"
    if not os.path.isdir(base_path):
        raise FileNotFoundError(f"Missing layerwise directory: {base_path}")

    pat = re.compile(r"^layer_(\d+)_(\d+)$")
    windows: List[Tuple[int, int]] = []
    # scandir holds an OS directory handle; the context manager releases it
    # deterministically even if an entry lookup raises.
    with os.scandir(base_path) as entries:
        for entry in entries:
            if not entry.is_dir():
                continue
            m = pat.match(entry.name)
            if not m:
                continue
            layer_start, layer_end = int(m.group(1)), int(m.group(2))
            if layer_end <= layer_start or layer_start < start_layer:
                continue
            windows.append((layer_start, layer_end))

    if not windows:
        raise RuntimeError(
            f"No layer windows found in {base_path}. Expected directories like layer_<start>_<end>."
        )
    # Full-tuple sort keeps the result deterministic when two windows share a start.
    windows.sort()
    return windows
def run_split_pipeline(exported_path: str, num_layers: int = 61, start_layer: int = 0) -> None:
    """Run stage 1: split every discovered layer window into its shard graph.

    Args:
        exported_path: Export root containing ``onnx_layerwise_tmp``.
        num_layers: Not read by this function; kept so existing callers that
            pass it keep working.
        start_layer: Windows starting below this layer are not processed.
    """
    shard_windows = _discover_layer_windows(exported_path, start_layer=start_layer)
    print(
        f"[START] split pipeline | exported_path={exported_path}, "
        f"start_layer={start_layer}, discovered_shards={len(shard_windows)}"
    )

    shard_total = len(shard_windows)
    for position in range(shard_total):
        first_layer, last_layer = shard_windows[position]
        print(f"[PROCESS] Layer window {first_layer}_{last_layer}")
        # The shard position and total let the splitter treat the final shard differently.
        split_layer_graph(position, shard_total, exported_path, first_layer, last_layer)

    print("[DONE] split pipeline complete")
if entry.is_file() and entry.name.endswith(DELETE_SUFFIXES): + files.append(entry.path) + return files + + +def rewrite_tensors_with_prefix( + model: onnx.ModelProto, + prefix: str, + func_attr_tens, + size_threshold: int = 1024, + file_chunk_size: int = 10 * 2**30, +) -> None: + size = 0 + file_num = 0 + + for tensor in external_data_helper._get_all_tensors(model): + if tensor.HasField("raw_data") and tensor.name != "int64_2" and tensor.name not in func_attr_tens: + tsize = len(tensor.raw_data) + if tsize > size_threshold: + if size + tsize > file_chunk_size: + file_num += 1 + size = tsize + else: + size += tsize + + external_data_helper.set_external_data(tensor, f"{prefix}_{file_num}.onnx.data") + + +def saving_prefix_file( + location: str, layer_start: int, layer_end: int, exported_path: str, final_data_dir: str +) -> None: + model = onnx.load(location, load_external_data=False) + + model_pref = onnx.compose.add_prefix(model, f"layer_{layer_start}/", rename_functions=False) + + base_dir = f"{exported_path}/onnx_layerwise_tmp/layer_{layer_start}_{layer_end}" + external_data_helper.load_external_data_for_model(model_pref, base_dir) + + func_attr_tens = set() + if model_pref.functions: + func_attr_tens = { + v.name for v in external_data_helper._get_attribute_tensors_from_graph(model_pref.functions[0]) + } + + rewrite_tensors_with_prefix( + model_pref, + prefix=f"layer_{layer_start}", + func_attr_tens=func_attr_tens, + ) + + out_dir = f"{exported_path}/{final_data_dir}" + os.makedirs(out_dir, exist_ok=True) + onnx.save(model_pref, f"{out_dir}/pref_{layer_start}.onnx") + + +def run_saving_prefix(layer_start: int, layer_end: int, exported_path: str, final_data_dir: str) -> int: + _, _, loc = _window_paths(exported_path, layer_start, layer_end) + saving_prefix_file(loc, layer_start, layer_end, exported_path, final_data_dir) + return layer_start + + +def run_prefix_pipeline( + exported_path: str, + num_layers: int = 61, + chunk_size: int = 8, + final_data_dir: str = 
def compare_onnx_func(func1: onnx.FunctionProto, func2: onnx.FunctionProto) -> bool:
    """Return True when two ONNX functions have the same structure.

    The functions must declare the same number of inputs, outputs and nodes.
    Nodes are then compared pairwise, in order:

    * a node input that is a function input of ``func1`` must correspond to
      the function input of ``func2`` at the same position;
    * any other node input must either carry the same name in both functions
      or refer to function outputs at the same position;
    * ``op_type`` and every attribute must be equal;
    * a node output that is a function output of ``func1`` must correspond to
      the function output of ``func2`` at the same position, and any other
      node output must carry the same name.

    Function-level names may therefore differ between the two functions, while
    intermediate value names may not.
    """
    if (
        len(func1.input) != len(func2.input)
        or len(func1.output) != len(func2.output)
        or len(func1.node) != len(func2.node)
    ):
        return False

    def first_positions(names):
        # name -> index of its first occurrence (same answer as list.index),
        # built once instead of rescanning the signature for every node edge.
        positions = {}
        for pos, name in enumerate(names):
            positions.setdefault(name, pos)
        return positions

    in1, in2 = first_positions(func1.input), first_positions(func2.input)
    out1, out2 = first_positions(func1.output), first_positions(func2.output)

    for node1, node2 in zip(func1.node, func2.node):
        if len(node1.input) != len(node2.input):
            return False
        for name1, name2 in zip(node1.input, node2.input):
            if name1 in in1:
                # .get() yields None for a non-signature name, which never equals an index.
                if in2.get(name2) != in1[name1]:
                    return False
            elif name1 != name2:
                # Differing names are only acceptable for positionally matching function outputs.
                if name1 not in out1 or out2.get(name2) != out1[name1]:
                    return False

        if node1.op_type != node2.op_type:
            return False
        if len(node1.attribute) != len(node2.attribute):
            return False
        if any(attr1 != attr2 for attr1, attr2 in zip(node1.attribute, node2.attribute)):
            return False

        if len(node1.output) != len(node2.output):
            return False
        for name1, name2 in zip(node1.output, node2.output):
            if name1 in out1:
                if out2.get(name2) != out1[name1]:
                    return False
            elif name1 != name2:
                return False

    return True
+ ) + props[p.key] = p.value + onnx.helper.set_model_props(model, props) + + m1_funcs = [f.name for f in m1.functions] + m2_funcs = [f.name for f in m2.functions] + decoder_variants = {} + + def assign_decoder_variant(base_name: str, func: onnx.FunctionProto, src_graph: onnx.GraphProto) -> str: + variants = decoder_variants.setdefault(base_name, []) + + for existing_func, assigned_name in variants: + if compare_onnx_func(func, existing_func): + return assigned_name + + assigned = base_name if not variants else f"{base_name}__v{len(variants) + 1}" + variants.append((func, assigned)) + if assigned != base_name: + update_node_calls(src_graph, base_name, assigned) + return assigned + + final_funcs = {} + all_names = set(m1_funcs + m2_funcs) + + for name in all_names: + in_m1 = name in m1_funcs + in_m2 = name in m2_funcs + + if in_m1 and in_m2: + func1 = m1.functions[m1_funcs.index(name)] + func2 = m2.functions[m2_funcs.index(name)] + + if compare_onnx_func(func1, func2): + final_funcs[(func1.domain, func1.name)] = func1 + else: + if is_decoder(name): + name1 = assign_decoder_variant(name, func1, m1.graph) + name2 = assign_decoder_variant(name, func2, m2.graph) + + f1 = func1 if func1.name == name1 else copy_with_name(func1, name1) + f2 = func2 if func2.name == name2 else copy_with_name(func2, name2) + final_funcs[(f1.domain, f1.name)] = f1 + final_funcs[(f2.domain, f2.name)] = f2 + else: + raise ValueError(f"Function '{name}' differs between models and is not a DecoderLayer.") + elif in_m1: + f = m1.functions[m1_funcs.index(name)] + final_funcs[(f.domain, f.name)] = f + elif in_m2: + f = m2.functions[m2_funcs.index(name)] + final_funcs[(f.domain, f.name)] = f + else: + raise ValueError("Function not found") + + graph2 = onnx.compose.merge_graphs(m1.graph, m2.graph, io_map) + model.graph.CopyFrom(graph2) + + for (domain, name), f in final_funcs.items(): + if f.name != name: + f = copy_with_name(f, name) + model.functions.MergeFrom([f]) + + return model + + +def 
def run_merge_pipeline(exported_path: str, num_layers: int = 61, final_data_dir: str = "final_data") -> str:
    """Run stage 3: chain the prefixed shard models into one merged ONNX model.

    Shards are merged from the last window towards the first. The first step
    merges the two right-most ``pref_<start>.onnx`` files; every later step
    merges the next ``pref_<left>.onnx`` with the previous
    ``merged_<right>-<last>.onnx``. ``CustomOpTransform`` is applied to the
    model produced by the final merge step, and ``RemovePrefix`` is applied to
    the saved final model, which is then saved again in place.

    Args:
        exported_path: Export root containing ``onnx_layerwise_tmp`` and the
            ``final_data_dir`` folder.
        num_layers: Not read by this function; kept so existing callers that
            pass it keep working.
        final_data_dir: Folder under ``exported_path`` holding the
            ``pref_*.onnx`` inputs and receiving the merged outputs.

    Returns:
        Path of the final merged model, or of the single ``pref_*.onnx`` file
        when only one shard was discovered (no merge is performed then).

    Raises:
        ValueError: If no shard was discovered.
        FileNotFoundError: If an expected input model is missing.
        RuntimeError: If the left-hand shard has no node whose name contains
            ``DecoderLayer``.
    """
    windows = _discover_layer_windows(exported_path, start_layer=0)
    if not windows:
        raise ValueError("Need at least one discovered shard to merge")

    base_dir = f"{exported_path}/{final_data_dir}"
    start = time.time()
    print(
        f"[START] merge pipeline | exported_path={exported_path}, "
        f"discovered_shards={len(windows)}, final_data_dir={final_data_dir}"
    )

    shard_starts = [layer_start for (layer_start, _) in windows]
    first_start, last_start = shard_starts[0], shard_starts[-1]

    if len(shard_starts) == 1:
        only_model = f"{base_dir}/pref_{first_start}.onnx"
        if not os.path.exists(only_model):
            raise FileNotFoundError(f"Missing input model: {only_model}")
        print(f"[DONE] merge pipeline skipped (single layer): {only_model}")
        return only_model

    final_merge_idx = len(shard_starts) - 2
    for idx in range(len(shard_starts) - 1):
        # Walk adjacent shard pairs from the right-most pair to the left-most.
        left = shard_starts[final_merge_idx - idx]
        right = shard_starts[final_merge_idx - idx + 1]

        m1_path = f"{base_dir}/pref_{left}.onnx"
        # After the first step the right-hand side is the running merge result.
        m2_path = f"{base_dir}/pref_{right}.onnx" if idx == 0 else f"{base_dir}/merged_{right}-{last_start}.onnx"

        if not os.path.exists(m1_path):
            raise FileNotFoundError(f"Missing input model: {m1_path}")
        if not os.path.exists(m2_path):
            raise FileNotFoundError(f"Missing input model: {m2_path}")

        print(f"[MERGE] {left}-{last_start}")
        m1_pref = onnx.load(m1_path, load_external_data=False)
        m2_pref = onnx.load(m2_path, load_external_data=False)

        decoder_nodes = [n for n in m1_pref.graph.node if "DecoderLayer" in n.name]
        if not decoder_nodes:
            raise RuntimeError(f"DecoderLayer node not found in {m1_path}")
        # NOTE(review): the previous code first chose decoder_nodes[1] when more
        # than one decoder node existed and then unconditionally overwrote the
        # result with decoder_nodes[0], so that branch never took effect. The
        # effective behaviour (first decoder node) is kept here; confirm whether
        # multi-layer windows should wire the LAST decoder node instead.
        decoder_output = list(decoder_nodes[0].output)

        merged_model = merge_models(
            m1_pref,
            m2_pref,
            io_map=[
                # Output index 2 is presumably the hidden-state output -- TODO confirm.
                (f"{decoder_output[2]}", f"layer_{right}/inputs_embeds"),
                (f"layer_{left}/position_ids", f"layer_{right}/position_ids"),
            ],
        )

        if idx == final_merge_idx:
            CustomOpTransform.apply(merged_model)

        out_path = f"{base_dir}/merged_{left}-{last_start}.onnx"
        onnx.save(merged_model, out_path)
        print(f"[SAVED] {out_path}")

    final_path = f"{base_dir}/merged_{first_start}-{last_start}.onnx"
    model = onnx.load(final_path, load_external_data=False)
    RemovePrefix.apply(model)
    onnx.save(model, final_path)
    print(f"[DONE] merge pipeline complete in {time.time() - start:.2f}s")
    return final_path
def _ensure_pretrained_window_attrs():
    """Give ``PreTrainedModel`` class-level ``_start`` and ``_end`` if it lacks them.

    Each missing attribute is created with the value ``0``; an attribute that
    already exists is left untouched.
    """
    base_cls = transformers.modeling_utils.PreTrainedModel
    for attr_name in ("_start", "_end"):
        if not hasattr(base_cls, attr_name):
            setattr(base_cls, attr_name, 0)
def _install_window_patch(model_cls):
    """Wrap ``model_cls.__init__`` so out-of-window layers are nulled after construction.

    The wrapper calls the original ``__init__`` and then
    ``_null_outside_window_layers`` on the new instance. The class is marked
    with ``_window_patch_installed`` so a repeated call does nothing.
    """
    already_patched = getattr(model_cls, "_window_patch_installed", False)
    if not already_patched:
        wrapped_init = model_cls.__init__

        @functools.wraps(wrapped_init)
        def _init_then_prune(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)
            _null_outside_window_layers(self)

        model_cls.__init__ = _init_then_prune
        model_cls._window_patch_installed = True
for layer_idx in range(layer_start, layer_end)] + ) + + required_shards = sorted( + { + shard_name + for checkpoint_key, shard_name in weight_map.items() + if any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes) + } + ) + filtered_weight_map = { + checkpoint_key: shard_name + for checkpoint_key, shard_name in weight_map.items() + if any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes) + } + if not filtered_weight_map: + raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.") + + with tempfile.TemporaryDirectory() as tmpdir: + temp_model_path = Path(tmpdir) + (temp_model_path / "config.json").write_text(text_config.to_json_string(use_diff=False)) + (temp_model_path / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": { + "total_size": sum((model_path / shard_name).stat().st_size for shard_name in required_shards) + }, + "weight_map": filtered_weight_map, + } + ) + ) + for shard_name in required_shards: + (temp_model_path / shard_name).symlink_to(model_path / shard_name) + + # We are loading a task checkpoint into the base text model, so disable the + # base/task prefix heuristic and let `key_mapping` strip `language_model.`. + original_base_model_prefix = deepseek_cls.base_model_prefix + deepseek_cls.base_model_prefix = "" + try: + model, loading_info = deepseek_cls.from_pretrained( + str(temp_model_path), + config=text_config, + local_files_only=True, + key_mapping={r"^language_model\.": ""}, + output_loading_info=True, + ) + finally: + deepseek_cls.base_model_prefix = original_base_model_prefix + + unexpected_keys = loading_info["unexpected_keys"] + missing_keys = loading_info["missing_keys"] + mismatched_keys = loading_info["mismatched_keys"] + if unexpected_keys or missing_keys or mismatched_keys: + raise RuntimeError( + "Failed to load the text-only Kimi K2.5 checkpoint slice cleanly. 
def _build_layer_windows(total_layers: int, start: int, end: int):
    """Split ``[0, total_layers)`` into consecutive export windows.

    When ``start`` is positive the first window is ``(0, start)``. The layers
    from ``start`` onwards are then covered by windows of width ``end - start``;
    the final window is clipped to ``total_layers``.

    Args:
        total_layers: Number of layers in the model.
        start: First layer of the reference window.
        end: One past the last layer of the reference window.

    Returns:
        A list of ``(window_start, window_end)`` tuples in ascending order.

    Raises:
        ValueError: Unless ``0 <= start < end <= total_layers``.
    """
    if start < 0 or start >= end or end > total_layers:
        raise ValueError(
            f"Invalid export window start={start}, end={end} for total_layers={total_layers}. "
            "Expected: 0 <= start < end <= total_layers."
        )

    width = end - start
    leading = [(0, start)] if start > 0 else []
    strided = [(lo, min(lo + width, total_layers)) for lo in range(start, total_layers, width)]
    return leading + strided
QEfficient.base.modeling_qeff.QEFFBaseModel._start = start + QEfficient.base.modeling_qeff.QEFFBaseModel._end = end + QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = total_layers + model, tokenizer = load_text_only_kimi(MODEL_PATH, num_hidden_layers=end - start) + qeff_model = QEFFAutoModelForCausalLM( + model, num_kv_heads_repeat=1, qaic_config=qaic_config, torch_dtype=torch.float16 + ) + onnx_path = qeff_model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=TS, + num_cores=16, + qaic_config=qaic_config, + use_onnx_subfunctions=True, + ) + if first_onnx_path is None: + first_onnx_path = Path(onnx_path) + + if first_onnx_path is None: + raise RuntimeError("No ONNX path produced during compilation.") + export_root = _resolve_export_root(first_onnx_path) + + if LAYERWISE_MODE == "single_QPC": + QEfficient.utils.compile_layerwise(str(export_root)) + QEfficient.utils.inference(str(export_root)) + else: + QEfficient.utils.layerwise_pipeline(str(export_root)) + + +if __name__ == "__main__": + main() diff --git a/tests/transformers/models/audio_models/test_audio_embedding_models.py b/tests/transformers/models/audio_models/test_audio_embedding_models.py index 64dc06a595..82c613e557 100644 --- a/tests/transformers/models/audio_models/test_audio_embedding_models.py +++ b/tests/transformers/models/audio_models/test_audio_embedding_models.py @@ -139,7 +139,6 @@ def check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config: Optional[str] = None, compare_results: Optional[bool] = False, ): - replace_transformers_quantizers() model_config = {"model_name": model_name} model_config["n_layer"] = n_layer @@ -200,7 +199,6 @@ def check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, 
compare_results=True, manual_cleanup=manual_cleanup, num_devices=4 @@ -211,7 +209,6 @@ def test_full_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4, manual_cleanup=manual_cleanup) diff --git a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py index 6509d02fe7..0c6fb29087 100644 --- a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py +++ b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py @@ -374,7 +374,6 @@ def check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, compare_results=True, manual_cleanup=manual_cleanup, num_devices=4 @@ -385,7 +384,6 @@ def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4, manual_cleanup=manual_cleanup) diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index cc2d074a08..f878acbe73 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -57,7 +57,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( retain_full_kv: Optional[bool] = None, 
compare_results: bool = False, ): - torch.manual_seed(42) replace_transformers_quantizers() model_hf = load_hf_causal_lm_model(model_name, num_hidden_layers=n_layer, config=config) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 4bf067e7c4..0568939cd2 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -31,7 +31,6 @@ @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -77,7 +76,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -123,7 +121,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -178,7 +175,6 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -244,7 +240,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", 
test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -310,7 +305,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 8dbb0915b8..8c61cdc98d 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -33,7 +33,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - if model_name in ModelConfig.FULL_MODEL_TESTS_TO_SKIP: pytest.skip(f"Skipping full model test for {model_name} due to resource constraints.") check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -55,7 +54,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup) @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -89,7 +87,6 @@ def test_full_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -104,7 +101,6 @@ def 
test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py index b6641d7951..f5f2384e67 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py @@ -32,7 +32,6 @@ @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -52,7 +51,6 @@ def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_ful @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -71,7 +69,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -97,7 +94,6 @@ def 
test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_fu @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -117,7 +113,6 @@ def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_ @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -137,7 +132,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_f @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") diff --git a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py index 0b488a5037..9d02acbd29 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py @@ -32,7 +32,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, 
num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -46,7 +45,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -61,7 +59,6 @@ def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -81,7 +78,6 @@ def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_clean @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -96,7 +92,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cle @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -112,7 +107,6 @@ def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_clea @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git 
a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py index 2ff366ece2..af8c3b70f0 100644 --- a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py +++ b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py @@ -127,7 +127,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_causal_lm_pytorch_vs_kv_vs_ai100( model_name=model_name, torch_dtype=torch.float16, manual_cleanup=manual_cleanup @@ -139,7 +138,6 @@ def test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ai100( @@ -152,7 +150,6 @@ def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_dummy_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( diff --git a/tests/transformers/models/image_text_to_text/test_custom_dtype.py b/tests/transformers/models/image_text_to_text/test_custom_dtype.py index 95f62f1ac9..f291c5d12c 100644 --- a/tests/transformers/models/image_text_to_text/test_custom_dtype.py +++ b/tests/transformers/models/image_text_to_text/test_custom_dtype.py @@ -41,7 +41,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: pytest.skip("Test skipped for this model due to some issues.") if 
model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: @@ -65,7 +64,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( def test_few_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: pytest.skip("Test skipped for this model due to some issues.") if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: diff --git a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py index 5c58508385..b3f42e1b0c 100644 --- a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py +++ b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py @@ -64,7 +64,6 @@ def check_blockedKV_onnx_function_count_with_subfunction( @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). check_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup=manual_cleanup) @@ -73,7 +72,6 @@ def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_ @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). 
n_layer = get_custom_n_layers(model_name) @@ -84,7 +82,6 @@ def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_c @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_dummy_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/subfunction/test_subfunction_vlm.py b/tests/transformers/subfunction/test_subfunction_vlm.py index baf690e638..39e2c6d0ac 100644 --- a/tests/transformers/subfunction/test_subfunction_vlm.py +++ b/tests/transformers/subfunction/test_subfunction_vlm.py @@ -50,7 +50,6 @@ def check_image_text_to_text_subfunction_core( num_devices: int = 1, config: Optional[AutoConfig] = None, ): - img_size = model_config_dict[model_name]["img_size"] img_url = model_config_dict[model_name]["img_url"] query = model_config_dict[model_name]["text_prompt"] @@ -117,7 +116,6 @@ def check_image_text_to_text_subfunction_core( @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core(model_name, kv_offload=kv_offload, manual_cleanup=manual_cleanup) @@ -127,7 +125,6 @@ def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_clea @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core( model_name, @@ -142,7 +139,6 @@ def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_clean @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def 
test_dummy_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) hf_config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, **model_config_dict[model_name].get("additional_params", {}) From 4c069c3dc71ad18357b07c6af2dd19dbf61b0c8c Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 16:13:52 +0530 Subject: [PATCH 44/51] Made minor fix Signed-off-by: Abhishek Kumar Singh --- QEfficient/utils/__init__.py | 9 +++ .../{inference.py => inference_pipeline.py} | 59 +++++++++++++------ pyproject.toml | 1 + run.py | 2 +- 4 files changed, 51 insertions(+), 20 deletions(-) rename QEfficient/utils/{inference.py => inference_pipeline.py} (88%) diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index cfe17ac452..473d095381 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -38,7 +38,16 @@ require_value, to_named_specializations, ) +from QEfficient.utils.compile_layerwise import ( # noqa: F401 + run_compile_layerwise as compile_layerwise, +) from QEfficient.utils.hash_utils import ( # noqa: F401 create_export_hash, hash_dict_params, ) +from QEfficient.utils.inference_pipeline import ( # noqa: F401 + inference_pipeline, +) +from QEfficient.utils.layerwise_pipeline import ( # noqa: F401 + layerwise_pipeline, +) diff --git a/QEfficient/utils/inference.py b/QEfficient/utils/inference_pipeline.py similarity index 88% rename from QEfficient/utils/inference.py rename to QEfficient/utils/inference_pipeline.py index d686fc29d4..7c87d5afca 100644 --- a/QEfficient/utils/inference.py +++ b/QEfficient/utils/inference_pipeline.py @@ -86,34 +86,27 @@ def output_placeholder(session: QAICInferenceSession, output_name: str) -> np.nd return np.zeros(shape, dtype=dtype) -def main() -> None: - parser = argparse.ArgumentParser(description="Run layerwise QAIC prefill + decode from a base path.") - parser.add_argument("base_path", type=Path, help="Path to onnx_layerwise_tmp 
(contains layer_*/...)") - parser.add_argument("--model-name", default="moonshotai/Kimi-K2.5") - parser.add_argument("--prompt", default="Help") - parser.add_argument("--max-len", type=int, default=32) - parser.add_argument( - "--device-start", - type=int, - default=None, - help="Optional starting device id. If set, layer i uses device_start + i.", - ) - args = parser.parse_args() - +def inference_pipeline( + base_path: str | Path, + model_name: str = "moonshotai/Kimi-K2.5", + prompt: str = "Help", + max_len: int = 32, + device_start: Optional[int] = None, +) -> List[int]: tokenizer = AutoTokenizer.from_pretrained( - args.model_name, + model_name, padding_side="right", trust_remote_code=True, ) - prompt_ids = tokenizer(args.prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist() + prompt_ids = tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist() all_ids = list(prompt_ids) - qpc_paths = discover_qpc_paths(args.base_path) + qpc_paths = discover_qpc_paths(Path(base_path)) print(f"[LOAD] Found {len(qpc_paths)} layer sessions") sessions: List[Dict[str, object]] = [] for i, qpc in enumerate(qpc_paths): - device_ids = [args.device_start + i] if args.device_start is not None else None + device_ids = [device_start + i] if device_start is not None else None session = QAICInferenceSession(str(qpc), device_ids=device_ids) session.skip_buffers( [n for n in session.input_names + session.output_names if "compressed_kv" in n or "k_pe" in n] @@ -167,7 +160,7 @@ def main() -> None: # Decode generated_ids: List[int] = [] - while len(all_ids) < args.max_len: + while len(all_ids) < max_len: next_token_id = int(np.argmax(logits, axis=-1)[0, 0]) generated_ids.append(next_token_id) all_ids.append(next_token_id) @@ -200,6 +193,34 @@ def main() -> None: print(generated_ids) print("Generated text:") print(tokenizer.decode(generated_ids, skip_special_tokens=True)) + return generated_ids + + +def inference_pipelines(base_path: str | 
Path) -> List[int]: + # Backward-compatible wrapper used by some local scripts. + return inference_pipeline(base_path=base_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run layerwise QAIC prefill + decode from a base path.") + parser.add_argument("base_path", type=Path, help="Path to onnx_layerwise_tmp (contains layer_*/...)") + parser.add_argument("--model-name", default="moonshotai/Kimi-K2.5") + parser.add_argument("--prompt", default="Help") + parser.add_argument("--max-len", type=int, default=32) + parser.add_argument( + "--device-start", + type=int, + default=None, + help="Optional starting device id. If set, layer i uses device_start + i.", + ) + args = parser.parse_args() + inference_pipeline( + base_path=args.base_path, + model_name=args.model_name, + prompt=args.prompt, + max_len=args.max_len, + device_start=args.device_start, + ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 9a3a639381..2c50db5b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "fsspec==2023.6.0", "sentencepiece==0.2.0", "onnx==1.18.0", + "onnx_ir", "onnxruntime==1.22", "numpy==1.26.4", "protobuf==6.31.0", diff --git a/run.py b/run.py index 225ea92ea6..6e7f5899ac 100644 --- a/run.py +++ b/run.py @@ -218,7 +218,7 @@ def main(): if LAYERWISE_MODE == "single_QPC": QEfficient.utils.compile_layerwise(str(export_root)) - QEfficient.utils.inference(str(export_root)) + QEfficient.utils.inference_pipelines(str(export_root)) else: QEfficient.utils.layerwise_pipeline(str(export_root)) From bf1b9e3458d4f7840978c08673cbec79088ef080 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 16:23:57 +0530 Subject: [PATCH 45/51] Update run.py Signed-off-by: Abhishek Kumar Singh --- run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run.py b/run.py index 6e7f5899ac..cf60591420 100644 --- a/run.py +++ b/run.py @@ -25,7 +25,7 @@ EXPORT_START = 1 EXPORT_END = 3 
-LAYERWISE_MODE = "pipeline" +LAYERWISE_MODE = "single_qpc" def _ensure_pretrained_window_attrs(): @@ -216,7 +216,7 @@ def main(): raise RuntimeError("No ONNX path produced during compilation.") export_root = _resolve_export_root(first_onnx_path) - if LAYERWISE_MODE == "single_QPC": + if LAYERWISE_MODE == "multiple_qpc": QEfficient.utils.compile_layerwise(str(export_root)) QEfficient.utils.inference_pipelines(str(export_root)) else: From ab312d065eaaa725ad576abcb56dabc5d7f68319 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 21:18:29 +0530 Subject: [PATCH 46/51] Update modeling_qeff.py Signed-off-by: Abhishek Kumar Singh --- QEfficient/base/modeling_qeff.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 4dad077d8c..0a72a6ffbe 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -416,9 +416,6 @@ def _export( model, transformed = onnx_transforms.apply(model, **transform_kwargs) onnx.save(model, layer_onnx_path_tmp) self.onnx_path = layer_onnx_path_tmp - import pdb - - pdb.set_trace() return layer_onnx_path_tmp def get_onnx_path( From 7cd93e366b4ea9995e3166f43233921d736e22c6 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Wed, 29 Apr 2026 03:04:16 +0530 Subject: [PATCH 47/51] Made minor fix Signed-off-by: Abhishek Kumar Singh --- QEfficient/utils/inference_pipeline.py | 2 +- run.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/utils/inference_pipeline.py b/QEfficient/utils/inference_pipeline.py index 7c87d5afca..b36ea9aa50 100644 --- a/QEfficient/utils/inference_pipeline.py +++ b/QEfficient/utils/inference_pipeline.py @@ -101,7 +101,7 @@ def inference_pipeline( prompt_ids = tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist() all_ids = list(prompt_ids) - qpc_paths = discover_qpc_paths(Path(base_path)) + qpc_paths = discover_qpc_paths(Path(base_path + 
"/onnx_layerwise_tmp")) print(f"[LOAD] Found {len(qpc_paths)} layer sessions") sessions: List[Dict[str, object]] = [] diff --git a/run.py b/run.py index cf60591420..5f423de5df 100644 --- a/run.py +++ b/run.py @@ -218,7 +218,7 @@ def main(): if LAYERWISE_MODE == "multiple_qpc": QEfficient.utils.compile_layerwise(str(export_root)) - QEfficient.utils.inference_pipelines(str(export_root)) + QEfficient.utils.inference_pipeline(str(export_root)) else: QEfficient.utils.layerwise_pipeline(str(export_root)) From 067ce7d14f330a2aa4e8518117f7ffce66da5ac7 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 18:28:36 +0530 Subject: [PATCH 48/51] Made minor fix Signed-off-by: Abhishek Kumar Singh --- QEfficient/utils/inference_pipeline.py | 2 +- dbg.log | 0 run.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 dbg.log diff --git a/QEfficient/utils/inference_pipeline.py b/QEfficient/utils/inference_pipeline.py index b36ea9aa50..925a429e00 100644 --- a/QEfficient/utils/inference_pipeline.py +++ b/QEfficient/utils/inference_pipeline.py @@ -203,7 +203,7 @@ def inference_pipelines(base_path: str | Path) -> List[int]: def main() -> None: parser = argparse.ArgumentParser(description="Run layerwise QAIC prefill + decode from a base path.") - parser.add_argument("base_path", type=Path, help="Path to onnx_layerwise_tmp (contains layer_*/...)") + parser.add_argument("base_path", type=Path, help="Path to onnx layer wise without onnx_layerwise_tmp ") parser.add_argument("--model-name", default="moonshotai/Kimi-K2.5") parser.add_argument("--prompt", default="Help") parser.add_argument("--max-len", type=int, default=32) diff --git a/dbg.log b/dbg.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/run.py b/run.py index 5f423de5df..8fb585945e 100644 --- a/run.py +++ b/run.py @@ -25,7 +25,7 @@ EXPORT_START = 1 EXPORT_END = 3 -LAYERWISE_MODE = "single_qpc" +LAYERWISE_MODE = "multiple_qpc" def _ensure_pretrained_window_attrs(): @@ 
-217,7 +217,7 @@ def main(): export_root = _resolve_export_root(first_onnx_path) if LAYERWISE_MODE == "multiple_qpc": - QEfficient.utils.compile_layerwise(str(export_root)) + # QEfficient.utils.compile_layerwise(str(export_root)) QEfficient.utils.inference_pipeline(str(export_root)) else: QEfficient.utils.layerwise_pipeline(str(export_root)) From 30c8e245dcd8e04b2510e12e581e756ff56f7b16 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Tue, 28 Apr 2026 18:29:22 +0530 Subject: [PATCH 49/51] Made minor fix Signed-off-by: Abhishek Kumar Singh --- dbg.log | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 dbg.log diff --git a/dbg.log b/dbg.log deleted file mode 100644 index e69de29bb2..0000000000 From b211a9a60272138ee9a50a046a5cb57920172ad5 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Wed, 29 Apr 2026 03:07:13 +0530 Subject: [PATCH 50/51] Made minor fix Signed-off-by: Abhishek Kumar Singh --- run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run.py b/run.py index 8fb585945e..f950d62e51 100644 --- a/run.py +++ b/run.py @@ -217,7 +217,7 @@ def main(): export_root = _resolve_export_root(first_onnx_path) if LAYERWISE_MODE == "multiple_qpc": - # QEfficient.utils.compile_layerwise(str(export_root)) + QEfficient.utils.compile_layerwise(str(export_root)) QEfficient.utils.inference_pipeline(str(export_root)) else: QEfficient.utils.layerwise_pipeline(str(export_root)) From 106f0654ff9efbb8677690a86c217b3c6d5434a9 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Wed, 29 Apr 2026 13:19:29 +0530 Subject: [PATCH 51/51] Added thread pool for loading the QPC Signed-off-by: Abhishek Kumar Singh --- QEfficient/utils/inference_pipeline.py | 108 ++++++++++++++++--------- 1 file changed, 72 insertions(+), 36 deletions(-) diff --git a/QEfficient/utils/inference_pipeline.py b/QEfficient/utils/inference_pipeline.py index 925a429e00..f7f6b64848 100644 --- a/QEfficient/utils/inference_pipeline.py +++ 
b/QEfficient/utils/inference_pipeline.py @@ -2,14 +2,17 @@ import argparse import re +import time +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import numpy as np from transformers import AutoTokenizer from QEfficient.generation.cloud_infer import QAICInferenceSession +SessionInfo = Dict[str, object] LAYER_DIR_RE = re.compile(r"layer_(\d+)_(\d+)$") @@ -86,12 +89,63 @@ def output_placeholder(session: QAICInferenceSession, output_name: str) -> np.nd return np.zeros(shape, dtype=dtype) +def resolve_base_path(base_path: str | Path) -> Path: + base = Path(base_path) + if (base / "onnx_layerwise_tmp").is_dir(): + return base + children = sorted( + p for p in base.iterdir() if p.is_dir() and (p / "onnx_layerwise_tmp").is_dir() + ) + if len(children) == 1: + return children[0] + if not children: + raise FileNotFoundError(f"No onnx_layerwise_tmp under: {base}") + raise RuntimeError( + f"Multiple candidate model directories under {base}. 
Pass one of: {[str(p) for p in children]}" + ) + + +def load_single_session(idx: int, qpc: Path, device_start: Optional[int]) -> Tuple[int, SessionInfo]: + device_ids = [device_start + idx] if device_start is not None else None + session = QAICInferenceSession(str(qpc), device_ids=device_ids) + session.skip_buffers( + [n for n in session.input_names + session.output_names if "compressed_kv" in n or "k_pe" in n] + ) + + out_name = pick_main_output_name(session) + session.set_buffers({out_name: output_placeholder(session, out_name)}) + + return idx, { + "session": session, + "token_input": pick_token_input_name(session), + "hidden_input": pick_hidden_input_name(session), + "pos_input": pick_pos_input_name(session), + "out_name": out_name, + } + + +def load_sessions_threaded( + qpc_paths: List[Path], device_start: Optional[int], max_workers: Optional[int] +) -> List[SessionInfo]: + worker_count = max_workers if max_workers is not None else min(64, len(qpc_paths) or 1) + indexed: List[Optional[SessionInfo]] = [None] * len(qpc_paths) + with ThreadPoolExecutor(max_workers=worker_count) as executor: + futures = [executor.submit(load_single_session, idx, qpc, device_start) for idx, qpc in enumerate(qpc_paths)] + for future in futures: + idx, info = future.result() + indexed[idx] = info + print(f"[LOAD] layer {idx}: {qpc_paths[idx]} -> out={info['out_name']}") + + return [info for info in indexed if info is not None] + + def inference_pipeline( base_path: str | Path, model_name: str = "moonshotai/Kimi-K2.5", - prompt: str = "Help", - max_len: int = 32, + prompt: str = "Help me with this", + max_len: int = 1000, device_start: Optional[int] = None, + max_workers: Optional[int] = None, ) -> List[int]: tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -101,31 +155,14 @@ def inference_pipeline( prompt_ids = tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist() all_ids = list(prompt_ids) - qpc_paths = discover_qpc_paths(Path(base_path + 
"/onnx_layerwise_tmp")) + resolved_base = resolve_base_path(base_path) + qpc_paths = discover_qpc_paths(resolved_base / "onnx_layerwise_tmp") print(f"[LOAD] Found {len(qpc_paths)} layer sessions") - sessions: List[Dict[str, object]] = [] - for i, qpc in enumerate(qpc_paths): - device_ids = [device_start + i] if device_start is not None else None - session = QAICInferenceSession(str(qpc), device_ids=device_ids) - session.skip_buffers( - [n for n in session.input_names + session.output_names if "compressed_kv" in n or "k_pe" in n] - ) - - out_name = pick_main_output_name(session) - session.set_buffers({out_name: output_placeholder(session, out_name)}) - - sessions.append( - { - "session": session, - "token_input": pick_token_input_name(session), - "hidden_input": pick_hidden_input_name(session), - "pos_input": pick_pos_input_name(session), - "out_name": out_name, - } - ) - print(f"[LOAD] layer {i}: {qpc} -> out={out_name}") + start = time.time() + sessions = load_sessions_threaded(qpc_paths, device_start, max_workers) + print(f"[LOAD] Total load time: {time.time() - start:.2f}s") if not sessions: raise RuntimeError("No sessions loaded") if sessions[0]["token_input"] is None: @@ -133,7 +170,6 @@ def inference_pipeline( logits = None - # Prefill: pass each prompt token through all layers for pos, token_id in enumerate(prompt_ids): hidden = None for i, info in enumerate(sessions): @@ -158,7 +194,8 @@ def inference_pipeline( if logits is None: raise RuntimeError("Prompt produced no logits") - # Decode + start = time.time() + print("[RUN] Starting inference pipeline") generated_ids: List[int] = [] while len(all_ids) < max_len: next_token_id = int(np.argmax(logits, axis=-1)[0, 0]) @@ -189,6 +226,7 @@ def inference_pipeline( hidden = outputs[info["out_name"]] logits = hidden + print(f"[RUN] Total inference time: {time.time() - start:.2f}s") print("Generated token ids:") print(generated_ids) print("Generated text:") @@ -196,30 +234,28 @@ def inference_pipeline( return 
generated_ids -def inference_pipelines(base_path: str | Path) -> List[int]: - # Backward-compatible wrapper used by some local scripts. - return inference_pipeline(base_path=base_path) - - def main() -> None: - parser = argparse.ArgumentParser(description="Run layerwise QAIC prefill + decode from a base path.") - parser.add_argument("base_path", type=Path, help="Path to onnx layer wise without onnx_layerwise_tmp ") + parser = argparse.ArgumentParser(description="Threaded layerwise QAIC pipeline") + parser.add_argument("base_path", type=Path, help="Path to model dir or parent of model dir") parser.add_argument("--model-name", default="moonshotai/Kimi-K2.5") - parser.add_argument("--prompt", default="Help") + parser.add_argument("--prompt", default="Help me with this") parser.add_argument("--max-len", type=int, default=32) parser.add_argument( "--device-start", type=int, - default=None, + default=0, help="Optional starting device id. If set, layer i uses device_start + i.", ) + parser.add_argument("--max-workers", type=int, default=None, help="Thread pool size for load") args = parser.parse_args() + inference_pipeline( base_path=args.base_path, model_name=args.model_name, prompt=args.prompt, max_len=args.max_len, device_start=args.device_start, + max_workers=args.max_workers, )