From 44296803940df111c03d45d2207a9353d89f67c2 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 5 May 2026 15:40:01 +0530 Subject: [PATCH 01/17] port int4 changes Signed-off-by: Mamta Singh --- QEfficient/base/onnx_transforms.py | 2 + QEfficient/base/pytorch_transforms.py | 12 +- QEfficient/customop/matmulnbits.py | 88 ++++- QEfficient/customop/quantization_ops.py | 149 ++++++++ .../models/deepseek_v3/modeling_deepseek.py | 352 ++++++++++++++++-- .../transformers/models/modeling_auto.py | 2 + .../transformers/models/pytorch_transforms.py | 16 +- .../quantizers/quant_transforms.py | 40 +- .../quantizer_compressed_tensors.py | 160 ++++++-- 9 files changed, 751 insertions(+), 70 deletions(-) create mode 100644 QEfficient/customop/quantization_ops.py diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index c27e3cc704..32bb0a8fa4 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -39,6 +39,7 @@ CtxScatterFuncCB, CtxScatterFuncCB3D, ) +from QEfficient.customop.quantization_ops import CastToUInt4, CastToUInt4Func from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc from QEfficient.utils.constants import FILE_CHUNK_SIZE_DEFAULT, ONNX_EXPORT_OPSET, SIZE_THRESHOLD_DEFAULT @@ -100,6 +101,7 @@ class CustomOpTransform(BaseOnnxTransform): "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), + "CastToUInt4": (CastToUInt4Func, CastToUInt4), } @classmethod diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index 812177eac2..a716805d36 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -5,7 +5,7 @@ # # ---------------------------------------------------------------------------- from types import MethodType -from typing import Callable, Dict, Tuple, Type +from typing import Callable, Dict, Optional, Tuple, Type from torch import nn @@ -97,6 +97,7 @@ class ModuleMutatorTransform(PytorchTransform): """ _match_class: nn.Module + _match_string: Optional[str] = None @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: @@ -135,7 +136,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__) ): for orig_method_name, mapped_method in repl_method_map.items(): - setattr(module, orig_method_name, MethodType(mapped_method, module)) + parts = orig_method_name.split(".") + if len(parts) > 1: + target = module + for part in parts[:-1]: + target = getattr(target, part) + setattr(target, parts[-1], MethodType(mapped_method, target)) + else: + setattr(module, orig_method_name, MethodType(mapped_method, module)) if hasattr(module, "__qeff_init__"): module.__qeff_init__() diff --git a/QEfficient/customop/matmulnbits.py b/QEfficient/customop/matmulnbits.py index e6249b0ad3..468f32e641 100644 --- a/QEfficient/customop/matmulnbits.py +++ b/QEfficient/customop/matmulnbits.py @@ -14,7 +14,10 @@ class QuantLinearTorchFunction(torch.autograd.Function): @staticmethod def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): - input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) + if qself_qzeros is None: + input_tuple = (x, qself_qweight, qself_scales) + else: + input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) input_tuple += (g_idx,) if g_idx is not None else () return g.op( "com.microsoft::MatMulNBits", @@ -28,7 +31,10 @@ def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group @staticmethod def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): + if qself_qzeros is None: + qself_qzeros = 2 ^ (bits - 1) if torch.onnx.is_in_onnx_export(): + # For faster export return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float() fp_weight = dequantize_blockwise_bits( qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features @@ -40,8 +46,7 @@ def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, grou def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols): if bits != 4: raise ValueError("Only bits=4 is supported for executing quantized model") - if group_size != 128: - raise ValueError("Only group_size=128 is supported for executing quantized model") + expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1) aligned_scale = scale.reshape(*quant_values.shape[:-1], 1) @@ -88,20 +93,20 @@ def __init__(self, bits, group_size, in_features, out_features, bias): q_rows = in_features // self.group_size self.register_buffer( "qweight", - torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8), + torch.empty((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8), ) self.register_buffer( "qzeros", - torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8), + torch.empty((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8), ) self.register_buffer( - "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16) + "scales", torch.empty((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16) ) self.register_buffer( "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32) ) if bias: - self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16)) + self.register_buffer("bias", torch.empty((out_features), dtype=torch.float16)) else: self.bias = None @@ -180,3 +185,72 @@ def forward(self, inputs): ) out = out + self.bias if self.bias is not None else out return out + + +class QMOE(torch.autograd.Function): + @staticmethod + def symbolic( + g, + x, + router_weights, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + router_probs, + activation_type, + block_size, + expert_weight_bits, + k, + ): + qmoe_out = g.op( + "com.microsoft::QMoE", + x, + router_weights, + router_probs, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + outputs=1, + activation_type_s=activation_type, # <-- _s suffix for string + block_size_i=block_size, + expert_weight_bits_i=expert_weight_bits, + k_i=k, + ) + + # # Create axes=-1 as an explicit int64 constant tensor + # axes = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + # # Compute mean of router_probs along the last axis, keepdims for broadcasting + # router_probs_mean = g.op("ReduceMean", router_probs, axes, keepdims_i=1) + + # Multiply qmoe_out with the averaged router_probs + return qmoe_out + # return g.op("Mul", qmoe_out, router_probs_mean)) + + @staticmethod + def forward( + ctx, + x, + router_weights, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + router_probs, + activation_type, + block_size, + expert_weight_bits, + k, + ): + # Dummy forward: simulate qmoe_out as zeros_like(x), then apply ReduceMean * Mul + qmoe_out = torch.zeros_like(x) + router_probs_mean = router_probs.mean(dim=-1, keepdim=True) + return qmoe_out * router_probs_mean diff --git a/QEfficient/customop/quantization_ops.py b/QEfficient/customop/quantization_ops.py new file mode 100644 index 0000000000..3878804c7d --- /dev/null +++ b/QEfficient/customop/quantization_ops.py @@ -0,0 +1,149 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import onnxscript +import torch +from onnx import TensorProto + +from QEfficient.utils import constants + +ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) + + +@onnxscript.script(onnxscript.values.Opset("com.qti.aisw.onnx", 1)) +def CastToUInt4(weight_packed: onnxscript.UINT8) -> onnxscript.UINT8: + """ + Unpack packed uint8 weights into uint4 values and cast output to UINT4. + Supports N-D input: all leading dimensions are preserved; only the last + dimension (in_features // 2) is doubled to (in_features). + + Input: (..., in_features // 2) UINT8 + Each byte holds two nibbles: byte = (w_y << 4) | (w_x & 0x0F) + Output: (..., in_features) UINT4, values in [0, 15] + + Operations: + w_x = weight_packed % 16 (lower nibble) + w_y = (weight_packed >> 4) % 16 (upper nibble) + stacked = concat([w_x, w_y], axis=-1) after unsqueeze + → (..., in//2, 2) + leading_dims = shape[:-1] + new_shape = [...leading_dims, last_dim * 2] + reshaped = reshape(stacked, new_shape) + output = Cast(reshaped, to=UINT4) + """ + sixteen = ops.CastLike(ops.Constant(value_ints=[16]), weight_packed) + + # Lower nibble: weight_packed & 0x0F = weight_packed % 16 + w_x = ops.Mod(weight_packed, sixteen) + + # Upper nibble: (weight_packed >> 4) & 0x0F + shift = ops.CastLike(ops.Constant(value_ints=[4]), weight_packed) + w_shifted = ops.BitShift(weight_packed, shift, direction="RIGHT") + w_y = ops.Mod(w_shifted, sixteen) + + # Stack along a new last dim → (..., in_features//2, 2) + w_x_unsq = ops.Unsqueeze(w_x, [-1]) + w_y_unsq = ops.Unsqueeze(w_y, [-1]) + stacked = ops.Concat(w_x_unsq, w_y_unsq, axis=-1) + + # N-D aware reshape: preserve all leading dims, double the last dim. + # packed_shape = [d0, d1, ..., last_dim] + packed_shape = ops.Shape(weight_packed) + # All dims except the last: [d0, d1, ...] + leading_dims = ops.Slice(packed_shape, starts=[0], ends=[-1], axes=[0]) + # Last dim only: [last_dim] + last_dim = ops.Slice(packed_shape, starts=[-1], ends=[2147483647], axes=[0]) + # Double the last dim: [last_dim * 2] + last_dim_doubled = ops.Mul(last_dim, ops.Constant(value_ints=[2])) + # New shape: [d0, d1, ..., last_dim * 2] + new_shape = ops.Concat(leading_dims, last_dim_doubled, axis=0) + reshaped = ops.Reshape(stacked, new_shape) + + # Cast to UINT4 — data_type value is version-dependent (21 in ONNX 1.18, 23 in newer) + return ops.Cast(reshaped, to=int(TensorProto.UINT4)) + + +class CastToUInt4Func(torch.autograd.Function): + """ + Custom op: unpacks packed uint8 → uint8 (values 0-15) in PyTorch. + In ONNX the custom op subgraph includes a Cast → UINT4 as its last step. + Supports N-D input: all leading dimensions are preserved. + + PyTorch forward : packed uint8 (..., in//2) → uint8 (..., in), values [0, 15] + ONNX symbolic : emits CastToUInt4 node (com.qti.aisw.onnx) + The subgraph ends with Cast → UINT4. + """ + + @staticmethod + def forward(weight_packed: torch.Tensor) -> torch.Tensor: + w_x = weight_packed & 0x0F # lower nibble, (..., in//2), range [0, 15] + w_y = (weight_packed >> 4) & 0x0F # upper nibble, (..., in//2), range [0, 15] + # New shape: all leading dims unchanged, last dim doubled + new_shape = list(weight_packed.shape[:-1]) + [weight_packed.shape[-1] * 2] + return torch.stack( + [w_x, w_y], dim=-1 + ).reshape( + new_shape + ) # Can't add a cast operation to uint4 here, as its not supported in pytorch; The ONNX export will handle the cast to IINT4 in the symbolic method. + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, weight_packed: torch.Value) -> torch.Value: + output = g.onnxscript_op(CastToUInt4, weight_packed) + return output + + +class DequantizeLinearFunc(torch.autograd.Function): + """ + Emits a standard ONNX DequantizeLinear node (ai.onnx domain, not custom). + + Symmetric blockwise quantization — no zero_point: + output = x * scale (per block along the last axis) + + Supports N-D input: + weight_unpacked : (..., in_features) — quantized values + scale : (..., num_blocks) — per-block scales + block_size : int — elements per block + + PyTorch forward : expand blockwise scale along last dim, multiply + ONNX symbolic : DequantizeLinear(weight_unpacked, scale, + axis=2, block_size=block_size) + axis=2 for 3D input (2, out_features, in_features). + No zero_point input (symmetric). + """ + + @staticmethod + def forward( + weight_unpacked: torch.Tensor, scale: torch.Tensor, zeros: torch.Tensor, block_size: int + ) -> torch.Tensor: + # Expand per-block scale → per-element scale along last dim + scale_expanded = scale.repeat_interleave(block_size, dim=-1) + zeros_expanded = zeros.repeat_interleave(block_size, dim=-1) + return (weight_unpacked.to(torch.int8) - zeros_expanded.to(torch.int8)) * scale_expanded + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic( + g: torch.Graph, weight_unpacked: torch.Value, scale: torch.Value, zeros: torch.Value, block_size: int + ) -> torch.Value: + # Standard DequantizeLinear: symmetric (no zero_point), blockwise. + # Input is 3D: (2, out_features, in_features) → axis=2 (last dim). + # DequantizeLinear natively supports batch dimensions. + return g.op( + "DequantizeLinear", + weight_unpacked, + scale, + zeros, + axis_i=2, + block_size_i=block_size, + ) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 33dcd6392b..925ac8de65 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -21,6 +21,8 @@ generic_blocked_mla_attention_interface, ) from QEfficient.customop.rms_norm import CustomRMSNormFunc +from QEfficient.customop.matmulnbits import QMOE, QuantLinearTorchFunction +from QEfficient.customop.quantization_ops import CastToUInt4Func, DequantizeLinearFunc from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MAX_POSITION_EMBEDDINGS, MIN_MASKED_ATTENTION_VALUE @@ -470,7 +472,13 @@ def fused_forward_orig( k_pe_expanded = k_pe_expanded[:, :q_heads, :, :] else: kva_expanded = kva - k_pe_expanded = k_pe + #k_pe_expanded = k_pe + num_heads_to_repeat = math.ceil(q_heads / k_heads) + k_pe_expanded = ( + k_pe.unsqueeze(2) + .expand(-1, -1, num_heads_to_repeat, -1, -1) + .reshape(bsz, num_heads_to_repeat * k_heads, -1, self.config.qk_rope_head_dim) + ) v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) value_states = torch.matmul(kva_expanded, v_up_per_head) @@ -749,31 +757,143 @@ def forward( **kwargs, ) +class QEffDeepseekMoEGate(nn.Module): + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + if self.topk_method == "noaux_tc": + assert not self.training + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + router_probs = tmp_scores + router_weights = scores + return topk_idx, topk_weight, router_probs, router_weights + class QEffDeepseekV3MoE(nn.Module): def __qeff_init__( self, ): - self.all_gate_proj = torch.nn.Parameter( - torch.cat( - [exp.gate_proj.compressor.decompress_module(exp.gate_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) + # Get common parameters from first expert + first_expert = self.experts[0] + self.bits = first_expert.gate_proj.bits + self.group_size = first_expert.gate_proj.group_size + self.act_fn = first_expert.act_fn + assert first_expert.gate_proj.act_order == first_expert.up_proj.act_order == first_expert.down_proj.act_order, ( + "act_order mismatch" ) - self.all_up_proj = torch.nn.Parameter( - torch.cat( - [exp.up_proj.compressor.decompress_module(exp.up_proj).T.unsqueeze(0) for exp in self.experts], dim=0 - ) + self.act_order = first_expert.gate_proj.act_order + + # Store dimensions for dequantization + self.in_features_gate, self.out_features_gate = ( + first_expert.gate_proj.in_features, + first_expert.gate_proj.out_features, ) - self.all_down_proj = torch.nn.Parameter( - torch.cat( - [exp.down_proj.compressor.decompress_module(exp.down_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) + self.in_features_up, self.out_features_up = first_expert.up_proj.in_features, first_expert.up_proj.out_features + self.in_features_down, self.out_features_down = ( + first_expert.down_proj.in_features, + first_expert.down_proj.out_features, ) - self.act_fn = self.experts[0].act_fn - def moe( + # Stack all parameters along a new dimension (expert dimension) + self.all_gate_qweight = torch.nn.Parameter( + torch.stack([exp.gate_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // 2 + ), + requires_grad=False, + ) + self.all_gate_scales = torch.nn.Parameter( + torch.stack([exp.gate_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // self.group_size + ), + requires_grad=False, + ) + # TODO: Since we know qzeros is always 8 -> Just embed this once into the operator as parameter -> explore this later + self.all_gate_qzeros = torch.nn.Parameter( + torch.stack([exp.gate_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_gate_gidx = torch.nn.Parameter( + torch.stack([exp.gate_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + self.all_up_qweight = torch.nn.Parameter( + torch.stack([exp.up_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // 2 + ), + requires_grad=False, + ) + self.all_up_scales = torch.nn.Parameter( + torch.stack([exp.up_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // self.group_size + ), + requires_grad=False, + ) + self.all_up_qzeros = torch.nn.Parameter( + torch.stack([exp.up_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_up_gidx = torch.nn.Parameter( + torch.stack([exp.up_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + self.all_down_qweight = torch.nn.Parameter( + torch.stack([exp.down_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // 2 + ), + requires_grad=False, + ) + self.all_down_scales = torch.nn.Parameter( + torch.stack([exp.down_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // self.group_size + ), + requires_grad=False, + ) + self.all_down_qzeros = torch.nn.Parameter( + torch.stack([exp.down_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_down_gidx = torch.nn.Parameter( + torch.stack([exp.down_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + def moe_old( self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, @@ -783,33 +903,199 @@ def moe( hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - gate_proj = self.all_gate_proj[topk_indices.flatten()] - up_proj = self.all_up_proj[topk_indices.flatten()] - down_proj = self.all_down_proj[topk_indices.flatten()] + for i in range(self.gate.top_k): + expert_idx = topk_indices[:, i] + curr_weight = topk_weights[:, i] + gate_qweight = self.all_gate_qweight[expert_idx].reshape( + seq_len * self.out_features_gate, + self.in_features_gate // self.group_size, + (self.group_size * self.bits) // 8, + ) + gate_scales = self.all_gate_scales[expert_idx].reshape( + seq_len * self.out_features_gate * (self.in_features_gate // self.group_size) + ) + gate_qzeros = self.all_gate_qzeros[expert_idx].reshape( + seq_len * self.out_features_gate, self.in_features_gate // self.group_size + ) + gate_gidx = self.all_gate_gidx[expert_idx].reshape(seq_len * self.in_features_gate) + + up_qweight = self.all_up_qweight[expert_idx].reshape( + seq_len * self.out_features_up, + self.in_features_up // self.group_size, + (self.group_size * self.bits) // 8, + ) + up_scales = self.all_up_scales[expert_idx].reshape( + seq_len * self.out_features_up * (self.in_features_up // self.group_size) + ) + up_qzeros = self.all_up_qzeros[expert_idx].reshape( + seq_len * self.out_features_up, self.in_features_up // self.group_size + ) + up_gidx = self.all_up_gidx[expert_idx].reshape(seq_len * self.in_features_up) + + down_qweight = self.all_down_qweight[expert_idx].reshape( + seq_len * self.out_features_down, + self.in_features_down // self.group_size, + (self.group_size * self.bits) // 8, + ) + down_scales = self.all_down_scales[expert_idx].reshape( + seq_len * self.out_features_down * (self.in_features_down // self.group_size) + ) + down_qzeros = self.all_down_qzeros[expert_idx].reshape( + seq_len * self.out_features_down, self.in_features_down // self.group_size + ) + down_gidx = self.all_down_gidx[expert_idx].reshape(seq_len * self.in_features_down) + + gate_out = QuantLinearTorchFunction.apply( + hidden_states, + gate_qweight, + gate_scales, + gate_qzeros, + gate_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_gate, + self.out_features_gate * seq_len, + ) + + up_out = QuantLinearTorchFunction.apply( + hidden_states, + up_qweight, + up_scales, + up_qzeros, + up_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_up, + self.out_features_up * seq_len, + ) + + hidden = self.act_fn(gate_out) * up_out + down_out = QuantLinearTorchFunction.apply( + hidden, + down_qweight, + down_scales, + down_qzeros, + down_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_down, + self.out_features_down, + ) + down_out = down_out.reshape(seq_len, self.out_features_down) + final_hidden_states += down_out * curr_weight.unsqueeze(1) + + return final_hidden_states + + def moe_weights_as_activations(self, hidden_states, router_probs, router_weights): + return QMOE.apply( + hidden_states, + router_weights, + self.fc1_experts_weights, + self.fc1_scales, + self.fc2_experts_weights, + self.fc2_scales, + self.fc3_experts_weights, + self.fc3_scales, + router_probs, + self.config.hidden_act, + self.group_size, + self.bits, + self.num_experts_per_tok, + ) + + @torch.no_grad() + def original_moe(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + # sorted_tokens_shape = sorted_tokens.shape + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + def moe_waa_unpack(self, hidden_states, topk_indices, topk_weights): + # GATHER - collect weights for selected experts + gate_proj_qweight = self.all_gate_qweight[topk_indices.flatten()] + gate_proj_scales = self.all_gate_scales[topk_indices.flatten()] + gate_proj_qzeros = self.all_gate_qzeros[topk_indices.flatten()] + + up_proj_qweight = self.all_up_qweight[topk_indices.flatten()] + up_proj_scales = self.all_up_scales[topk_indices.flatten()] + up_proj_qzeros = self.all_up_qzeros[topk_indices.flatten()] + + down_proj_qweight = self.all_down_qweight[topk_indices.flatten()] + down_proj_scales = self.all_down_scales[topk_indices.flatten()] + down_proj_qzeros = self.all_down_qzeros[topk_indices.flatten()] + + gate_proj_unpacked = CastToUInt4Func.apply(gate_proj_qweight) + gate_zeros_unpacked = CastToUInt4Func.apply(gate_proj_qzeros) + gate_proj_dq = DequantizeLinearFunc.apply( + gate_proj_unpacked, gate_proj_scales, gate_zeros_unpacked, self.group_size + ) + + up_proj_unpacked = CastToUInt4Func.apply(up_proj_qweight) + up_zeros_unpacked = CastToUInt4Func.apply(up_proj_qzeros) + up_proj_dq = DequantizeLinearFunc.apply(up_proj_unpacked, up_proj_scales, up_zeros_unpacked, self.group_size) + + down_proj_unpacked = CastToUInt4Func.apply(down_proj_qweight) + down_zeros_unpacked = CastToUInt4Func.apply(down_proj_qzeros) + down_proj_dq = DequantizeLinearFunc.apply( + down_proj_unpacked, down_proj_scales, down_zeros_unpacked, self.group_size + ) + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) expert_in = ( - hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.config.hidden_size) + hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.in_features_gate) ) - gate_out = torch.bmm(expert_in, gate_proj) - up_out = torch.bmm(expert_in, up_proj) + + gate_out = torch.bmm(expert_in, gate_proj_dq.transpose(1, 2)) + up_out = torch.bmm(expert_in, up_proj_dq.transpose(1, 2)) hidden = self.act_fn(gate_out) * up_out - expert_output = torch.bmm(hidden, down_proj) - experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) - experts_out = experts_out * topk_weights.unsqueeze(-1) + down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2)) - final_hidden_states = torch.einsum("abc->ac", experts_out) + down_out = down_out.view(-1, self.gate.top_k, self.out_features_down) - return final_hidden_states.type(hidden_states.dtype) + down_out = down_out * topk_weights.unsqueeze(-1) + + return torch.einsum("abc-> ac", down_out) def forward(self, hidden_states): + print("Using new MoE forward with weights as activations") residuals = hidden_states orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) + topk_indices, topk_weights, router_probs, router_weights = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + # hidden_states = self.moe_weights_as_activations(hidden_states, router_probs, router_weights).view(*orig_shape) + hidden_states = self.moe_waa_unpack(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states - +''' class QEffPrefillOnlyDeepseekV3MoE(nn.Module): def __qeff_init__( self, @@ -872,7 +1158,7 @@ def forward(self, hidden_states): """ residuals = hidden_states orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) + topk_indices, topk_weights, _, _ = self.gate(hidden_states) # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -891,7 +1177,7 @@ def forward(self, hidden_states): hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states - +''' class QEffDeepseekV3DecoderLayer(nn.Module): """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c5c50f1c7d..50bb99e40f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -70,6 +70,7 @@ FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, Mxfp4GptOssExpertDequantizeTransform, + PackQuantizedInt4ToMatMulNBitsTransform, ) from QEfficient.utils import ( constants, @@ -2743,6 +2744,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, + PackQuantizedInt4ToMatMulNBitsTransform, Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5ff06e6443..a1ea859545 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -250,13 +250,14 @@ QEffDisentangledSelfAttention, ) from QEfficient.transformers.models.deepseek_v3.modeling_deepseek import ( + QEffDeepseekMoEGate, QEffDeepseekV3Attention, QEffDeepseekV3CustomRMSNormAIC, QEffDeepseekV3DecoderLayer, QEffDeepseekV3ForCausalLM, QEffDeepseekV3Model, QEffDeepseekV3MoE, - QEffPrefillOnlyDeepseekV3MoE, + #QEffPrefillOnlyDeepseekV3MoE, ) from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, @@ -1034,8 +1035,11 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, "DeepseekV3MoE": { "forward": QEffDeepseekV3MoE.forward, - "moe": QEffDeepseekV3MoE.moe, + "moe_weights_as_activations": QEffDeepseekV3MoE.moe_weights_as_activations, + "moe_waa_unpack": QEffDeepseekV3MoE.moe_waa_unpack, + "original_moe": QEffDeepseekV3MoE.original_moe, "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, + "gate.forward": QEffDeepseekMoEGate.forward, }, "DeepseekV3Attention": { "forward": QEffDeepseekV3Attention.forward, @@ -1057,9 +1061,9 @@ class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} _match_string_replace_method = { "DeepseekV3MoE": { - "forward": QEffPrefillOnlyDeepseekV3MoE.forward, - "moe": QEffPrefillOnlyDeepseekV3MoE.moe, - "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + "forward": QEffDeepseekV3MoE.forward, + #"moe": QEffPrefillOnlyDeepseekV3MoE.moe, + #"__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, }, } @@ -1069,7 +1073,7 @@ class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransfo _match_string_replace_method = { "DeepseekV3MoE": { "forward": QEffDeepseekV3MoE.forward, - "moe": QEffDeepseekV3MoE.moe, + "moe": QEffDeepseekV3MoE.moe_waa_unpack, "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, } diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index f97bfe998e..cd4fc98bf7 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- import torch +from compressed_tensors.compressors import PackedQuantizationCompressor +from compressed_tensors.linear.compressed_linear import CompressedLinear from torch import nn from transformers import AutoConfig from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts @@ -53,7 +55,6 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module): original_module.bits, original_module.group_size, ) - original_module.weight = fp16_weight new_module = QuantLinearORT( original_module.bits, @@ -67,6 +68,43 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module): return new_module +class PackQuantizedInt4ToMatMulNBitsTransform(ModuleMutatorTransform): + """ + This transform is used to pack the quantized int4 weights into a format that can be used by the MatMulNBits kernel. + It is used for the ONNX export of the quantized model. + """ + + _match_class = CompressedLinear + + @classmethod + def mutate(cls, original_module, parent_module): + # add compressor.decompress to get the decompressed weight + # and then package into matmulnbit + assert isinstance(original_module.compressor, PackedQuantizationCompressor), ( + f"Only {PackedQuantizationCompressor} supported for now" + ) + fp_weight = original_module.compressor.decompress_module(original_module) + scales = original_module.weight_scale + # assuming symmetric quantization + quantization_args = original_module.quantization_scheme.weights + zeros = (torch.zeros_like(scales) + pow(2, (quantization_args.num_bits - 1))).to(torch.uint8) + g_idx = torch.arange(original_module.in_features // quantization_args.group_size).repeat_interleave( + quantization_args.group_size + ) + original_module.weight = torch.nn.Parameter(fp_weight) + assert quantization_args.type == "int", "uint is not tested yet" + new_module = QuantLinearORT( + quantization_args.num_bits, + quantization_args.group_size, + original_module.in_features, + original_module.out_features, + original_module.bias is not None, + ) + new_module.bias = original_module.bias if original_module.bias is not None else None + new_module.pack(original_module, scales, zeros, g_idx) + return new_module + + class GPTQToMatmulNbitsTransform(ModuleMutatorTransform): """ A transformation class that mutates a ``QuantLinearGPTQ`` module to a ``QuantLinearORT`` diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index f7ecc5b218..0025b916a9 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -12,6 +12,7 @@ import torch from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer +from transformers.utils import is_compressed_tensors_available from transformers.utils.quantization_config import CompressedTensorsConfig, QuantizationConfigMixin, QuantizationMethod from QEfficient.transformers.quantizers.quantizer_utils import blockwise_dequantize, get_keys_to_not_convert @@ -390,7 +391,55 @@ def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str class QEffCompressedTensorsConfig(CompressedTensorsConfig): - def __init__( + def handle_pack_quantized_init( + self, + config_groups=None, + format="dense", + quantization_status="initialized", + kv_cache_scheme=None, + global_compression_ratio=None, + ignore=None, + sparsity_config=None, + quant_method="compressed-tensors", + run_compressed: bool = True, + **kwargs, + ): + if is_compressed_tensors_available(): + from compressed_tensors.config import SparsityCompressionConfig + from compressed_tensors.quantization import QuantizationConfig + else: + raise ImportError( + "compressed_tensors is not installed and is required for compressed-tensors quantization. Please install it with `pip install compressed-tensors`." + ) + self.quantization_config = None + self.sparsity_config = None + + self.run_compressed = run_compressed + assert self.run_compressed, "pack-quantized needs to have run_compressed set to True" + + # parse from dict to load nested QuantizationScheme objects + if config_groups or kv_cache_scheme: + self.quantization_config = QuantizationConfig.model_validate( + { + "config_groups": config_groups, + "quant_method": quant_method, + "format": format, + "quantization_status": quantization_status, + "kv_cache_scheme": kv_cache_scheme, + "global_compression_ratio": global_compression_ratio, + "ignore": ignore, + **kwargs, + } + ) + + if sparsity_config: + self.sparsity_config = SparsityCompressionConfig.load_from_registry( + sparsity_config.get("format"), **sparsity_config + ) + + self.quant_method = QuantizationMethod.COMPRESSED_TENSORS + + def handle_fp8_init( self, config_groups=None, format="dense", @@ -480,7 +529,50 @@ def __init__( self.quant_method = QuantizationMethod.COMPRESSED_TENSORS + def __init__( + self, + config_groups=None, + format="dense", + quantization_status="initialized", + kv_cache_scheme=None, + global_compression_ratio=None, + ignore=None, + sparsity_config=None, + quant_method="compressed-tensors", + run_compressed: bool = None, + **kwargs, + ): + if format == "pack-quantized": + self.handle_pack_quantized_init( + config_groups=config_groups, + format=format, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=global_compression_ratio, + ignore=ignore, + sparsity_config=sparsity_config, + quant_method=quant_method, + run_compressed=True if run_compressed is None else run_compressed, + **kwargs, + ) + else: + self.handle_fp8_init( + config_groups=config_groups, + format=format, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=global_compression_ratio, + ignore=ignore, + sparsity_config=sparsity_config, + quant_method=quant_method, + run_compressed=False if run_compressed is None else run_compressed, + **kwargs, + ) + def to_dict(self): + if self.quantization_config.format == "pack-quantized": + return super().to_dict() + return { "quantization_config": { "config_groups": self.config_groups, @@ -501,39 +593,59 @@ def to_dict(self): class QEffCompressedTensorsFP8Quantizer(CompressedTensorsHfQuantizer): requires_calibration = False - def __init__(self, quantization_config, **kwargs): - # TODO: check if more checks are required - if not isinstance(quantization_config, QEffCompressedTensorsConfig): - raise TypeError( - f"Only {QEffCompressedTensorsConfig} is supported for initialization got {type(quantization_config)}" - ) - self.run_compressed = quantization_config.run_compressed - self.quantization_config = quantization_config - - # -- Handle extra kwargs below -- - self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) - self.modules_to_not_convert = list( - set(self.modules_to_not_convert if self.modules_to_not_convert else []) - | set(self.quantization_config.ignore if self.quantization_config.ignore else []) + @staticmethod + def is_pack_quantized(quant_config): + return ( + hasattr(quant_config, "quantization_config") + and hasattr(quant_config.quantization_config, "format") + and quant_config.quantization_config.format == "pack-quantized" ) - self.pre_quantized = kwargs.pop("pre_quantized", True) - if not self.pre_quantized and self.requires_calibration: - raise ValueError( - f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." - f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " - f"pass `pre_quantized=True` while knowing what you are doing." + def __init__(self, quantization_config, **kwargs): + if self.is_pack_quantized(quantization_config): + super().__init__(quantization_config, **kwargs) + else: + if not isinstance(quantization_config, QEffCompressedTensorsConfig): + raise TypeError( + f"Only {QEffCompressedTensorsConfig} is supported for initialization got {type(quantization_config)}" + ) + self.run_compressed = quantization_config.run_compressed + self.quantization_config = quantization_config + + # -- Handle extra kwargs below -- + self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + self.modules_to_not_convert = list( + set(self.modules_to_not_convert if self.modules_to_not_convert else []) + | set(self.quantization_config.ignore if self.quantization_config.ignore else []) ) + self.pre_quantized = kwargs.pop("pre_quantized", True) + + if not self.pre_quantized and self.requires_calibration: + raise ValueError( + f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." + f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " + f"pass `pre_quantized=True` while knowing what you are doing." + ) def validate_environment(self, *args, **kwargs): + if self.is_pack_quantized(self.quantization_config): + return super().validate_environment(*args, **kwargs) + return True def update_torch_dtype(self, torch_dtype): + if self.is_pack_quantized(self.quantization_config): + return super().update_torch_dtype(torch_dtype) + if torch_dtype not in [None, torch.float32]: logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") return None def _process_model_before_weight_loading(self, model, **kwargs): + if self.is_pack_quantized(self.quantization_config): + super()._process_model_before_weight_loading(model, **kwargs) + return + if self.quantization_config.targets != ["Linear"]: raise NotImplementedError( f"Only Linear layer with FP8 quantization are supported got targets = {self.quantization_config.targets}" @@ -561,9 +673,15 @@ def replace_linear_with_fp8_dequant_layer(module): replace_linear_with_fp8_dequant_layer(model) def _process_model_after_weight_loading(self, model, **kwargs): + if self.is_pack_quantized(self.quantization_config): + super()._process_model_after_weight_loading(model, **kwargs) + return pass def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]: + if self.is_pack_quantized(self.quantization_config): + return super().update_missing_keys_after_loading(model, missing_keys=missing_keys, prefix=prefix) + return missing_keys def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str = None) -> List[str]: From 28ba2b0df4f028660ab95553a9de54578916007b Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 5 May 2026 15:41:21 +0530 Subject: [PATCH 02/17] prefill_changes Signed-off-by: Mamta Singh --- QEfficient/customop/__init__.py | 6 + QEfficient/customop/ctx_scatter_gather.py | 100 +++++++- .../models/deepseek_v3/modeling_deepseek.py | 213 ++++++++++++------ .../transformers/models/pytorch_transforms.py | 6 +- 4 files changed, 252 insertions(+), 73 deletions(-) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index 35830aa91e..4830e660c3 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -8,9 +8,12 @@ from QEfficient.customop.ctx_scatter_gather import ( CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, CtxScatterFunc, CtxScatterFunc3D, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncBlockedKVCB, @@ -26,7 +29,10 @@ "CtxGatherFuncBlockedKV", "CtxScatterFunc", "CtxGatherFunc3D", + "CtxGatherFunc3DGeneralized", "CtxScatterFunc3D", + "CtxScatterFunc3DGeneralized", + "CtxScatterFunc3DInt", "CustomRMSNormAIC", "GemmaCustomRMSNormAIC", "CtxGatherFuncCB", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 59bfe6af03..01085ac967 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates # Create indices batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + + # keep index tensor types aligned for backend that require exact dtype match + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) indices = ops.Concat(batch_idx, ctx_idx, axis=2) @@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() batch_idx = torch.arange(data.shape[0]).view(-1, 1) - ctx_idx = position_ids + ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) data[batch_idx, ctx_idx] = updates return data @@ -92,6 +96,74 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) +class CtxScatterFunc3DGeneralized(torch.autograd.Function): + """Scatter variant that preserves ``data`` at invalid (INT32_MAX) positions. + + Unlike :class:`CtxScatterFunc3D`, which writes updates for invalid rows to + ``data.shape[1]-1`` (potentially clobbering valid content), this version + masks out invalid rows before scattering so ``data`` is left untouched where + ``position_ids == INT32_MAX``. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter3DInt( + data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32 +) -> onnxscript.INT32: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) + indices = ops.Concat(batch_idx, ctx_idx, axis=2) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc3DInt(torch.autograd.Function): + """Int32-typed scatter used to build a packed->original index table.""" + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) @@ -103,6 +175,7 @@ class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, ctx_indices] @staticmethod @@ -114,6 +187,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data) +class CtxGatherFunc3DGeneralized(torch.autograd.Function): + """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0). + + Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but + exposed as a separate autograd op so callers using the packed/cumsum scatter + pipeline can be easily recognized and so the ONNX symbolic omits + ``setTypeAs`` (needed when the caller already has a matching dtype on + ``data`` and wants the op signature to flow through without dtype pinning). + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) + return data[batch_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather3D, data, ctx_indices) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather( data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 925ac8de65..d9f34ec2f1 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -20,6 +20,11 @@ generic_blocked_attention_interface, generic_blocked_mla_attention_interface, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.customop.matmulnbits import QMOE, QuantLinearTorchFunction from QEfficient.customop.quantization_ops import CastToUInt4Func, DequantizeLinearFunc @@ -757,6 +762,31 @@ def forward( **kwargs, ) + +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this + # can be switched back to ``torch.full_like(token_idx, int32_max)``. + matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + class QEffDeepseekMoEGate(nn.Module): def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape @@ -1095,89 +1125,132 @@ def forward(self, hidden_states): hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states -''' + class QEffPrefillOnlyDeepseekV3MoE(nn.Module): - def __qeff_init__( + + def _cumsum_scatter_gather_update_expert_blocked( self, - ): - for exp in self.experts: - gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + x: torch.Tensor, + T2Ei: torch.Tensor, + expert, +# W_g: torch.Tensor, +# W_u: torch.Tensor, +# W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + act_fn, + T: int, + packed_chunk_size: int, + ) -> torch.Tensor: + """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. + + Accumulates one local expert's contribution in-place onto ``expert_out``. + Uses a packed/cumsum layout so the MLP runs only over active rows, then + scatters the weighted output back to original token positions. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) - gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) - up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) - down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) - setattr(exp, "gate_proj", gate_proj) - setattr(exp, "up_proj", up_proj) - setattr(exp, "down_proj", down_proj) + for packed_start in range(0, seq_len, packed_chunk_size): + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] - def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - for expert_idx in range(num_experts): + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = expert.gate_proj(x_chunk) + up_prime = expert.up_proj(x_chunk) + down_chunk = expert.down_proj((up_prime * act_fn(gate_prime))) + + #gate_prime = x_chunk @ W_g + #up_prime = x_chunk @ W_u + #down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + if len(self.experts) % num_nsp != 0: + raise ValueError( + f"num_experts ({len(self.experts)}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + ) + local_experts = len(self.experts) // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() +# W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() +# W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() +# W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = self._cumsum_scatter_gather_update_expert_blocked( + x=x, + T2Ei=T2Ei, + expert=self.experts[slot], +# W_g=W_g[:, slot], +# W_u=W_u[:, slot], +# W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + act_fn=self.experts[0].act_fn, + T=T, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + ) + return expert_out.sum(dim=0) + + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + topk_idx, topk_weight, router_probs, router_weights = self.gate(hidden_states) + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + + routing_weights = torch.zeros(T, self.config.n_routed_experts) + routing_weights.scatter_(1, topk_idx, topk_weight) + + if len(self.experts) % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H) + + final_hidden_states = x.new_zeros((T, H)) + for expert_idx in range(self.n_routed_experts): expert = self.experts[expert_idx] gate_out = expert.gate_proj(hidden_states) up_out = expert.up_proj(hidden_states) hidden = expert.act_fn(gate_out) * up_out expert_output = expert.down_proj(hidden) - current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) + current_hidden_states = expert_output * routing_weights[:, expert_idx].unsqueeze(-1) final_hidden_states += current_hidden_states - return final_hidden_states.type(hidden_states.dtype) - - def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - r""" - CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused - to not have to do a loop here (deepseek has 256 experts soooo yeah). - """ - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) - expert_mask = expert_mask.permute(2, 0, 1) - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) - - if token_indices.numel() > 0: - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, weighted_output) - - # in original deepseek, the output of the experts are gathered once we leave this module - # thus the moe module is itelsf an IsolatedParallel module - # and all expert are "local" meaning we shard but we don't gather - return final_hidden_states.type(hidden_states.dtype) + return final_hidden_states.view(B, S, H) - def forward(self, hidden_states): - """ - Forward pass of MoE block. - """ - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights, _, _ = self.gate(hidden_states) - # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) - mask.scatter_(1, topk_indices, topk_weights) - if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: - hidden_states = self.moe_blocked_weights_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: - hidden_states = self.moe_blocked_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - else: - hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) - - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states -''' class QEffDeepseekV3DecoderLayer(nn.Module): """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index a1ea859545..61e44acdc8 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -257,7 +257,7 @@ QEffDeepseekV3ForCausalLM, QEffDeepseekV3Model, QEffDeepseekV3MoE, - #QEffPrefillOnlyDeepseekV3MoE, + QEffPrefillOnlyDeepseekV3MoE, ) from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, @@ -1061,9 +1061,11 @@ class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} _match_string_replace_method = { "DeepseekV3MoE": { - "forward": QEffDeepseekV3MoE.forward, + "forward": QEffPrefillOnlyDeepseekV3MoE.forward, #"moe": QEffPrefillOnlyDeepseekV3MoE.moe, #"__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + "_forward_expert_blocked": QEffPrefillOnlyDeepseekV3MoE._forward_expert_blocked, + "_cumsum_scatter_gather_update_expert_blocked": QEffPrefillOnlyDeepseekV3MoE._cumsum_scatter_gather_update_expert_blocked, }, } From c5d239742e4baa03c3abc6b574c68f0d5d352623 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 6 May 2026 17:07:23 +0530 Subject: [PATCH 03/17] use casttouint4 in prefill and example script for kimi-k2.5 Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek.py | 176 +++++++++++-- .../transformers/models/pytorch_transforms.py | 3 +- examples/kimi_k2/export_kimi_k25.py | 238 ++++++++++++++++++ 3 files changed, 390 insertions(+), 27 deletions(-) create mode 100644 examples/kimi_k2/export_kimi_k25.py diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index eb404d1587..2169f2bf0d 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1126,18 +1126,113 @@ def forward(self, hidden_states): class QEffPrefillOnlyDeepseekV3MoE(nn.Module): + def __qeff_init__( + self, + ): + # Get common parameters from first expert + first_expert = self.experts[0] + self.bits = first_expert.gate_proj.bits + self.group_size = first_expert.gate_proj.group_size + self.act_fn = first_expert.act_fn + assert first_expert.gate_proj.act_order == first_expert.up_proj.act_order == first_expert.down_proj.act_order, ( + "act_order mismatch" + ) + self.act_order = first_expert.gate_proj.act_order + + # Store dimensions for dequantization + self.in_features_gate, self.out_features_gate = ( + first_expert.gate_proj.in_features, + first_expert.gate_proj.out_features, + ) + self.in_features_up, self.out_features_up = first_expert.up_proj.in_features, first_expert.up_proj.out_features + self.in_features_down, self.out_features_down = ( + first_expert.down_proj.in_features, + first_expert.down_proj.out_features, + ) + + # Stack all parameters along a new dimension (expert dimension) + self.all_gate_qweight = torch.nn.Parameter( + torch.stack([exp.gate_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // 2 + ), + requires_grad=False, + ) + self.all_gate_scales = torch.nn.Parameter( + torch.stack([exp.gate_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // self.group_size + ), + requires_grad=False, + ) + # TODO: Since we know qzeros is always 8 -> Just embed this once into the operator as parameter -> explore this later + self.all_gate_qzeros = torch.nn.Parameter( + torch.stack([exp.gate_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_gate_gidx = torch.nn.Parameter( + torch.stack([exp.gate_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + self.all_up_qweight = torch.nn.Parameter( + torch.stack([exp.up_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // 2 + ), + requires_grad=False, + ) + self.all_up_scales = torch.nn.Parameter( + torch.stack([exp.up_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // self.group_size + ), + requires_grad=False, + ) + self.all_up_qzeros = torch.nn.Parameter( + torch.stack([exp.up_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_up_gidx = torch.nn.Parameter( + torch.stack([exp.up_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + self.all_down_qweight = torch.nn.Parameter( + torch.stack([exp.down_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // 2 + ), + requires_grad=False, + ) + self.all_down_scales = torch.nn.Parameter( + torch.stack([exp.down_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // self.group_size + ), + requires_grad=False, + ) + self.all_down_qzeros = torch.nn.Parameter( + torch.stack([exp.down_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_down_gidx = torch.nn.Parameter( + torch.stack([exp.down_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + def _cumsum_scatter_gather_update_expert_blocked( self, x: torch.Tensor, T2Ei: torch.Tensor, - expert, - # W_g: torch.Tensor, - # W_u: torch.Tensor, - # W_d: torch.Tensor, + slot_gate_qweight: torch.Tensor, + slot_gate_scales: torch.Tensor, + slot_gate_qzeros: torch.Tensor, + slot_up_qweight: torch.Tensor, + slot_up_scales: torch.Tensor, + slot_up_qzeros: torch.Tensor, + slot_down_qweight: torch.Tensor, + slot_down_scales: torch.Tensor, + slot_down_qzeros: torch.Tensor, routing_weight: torch.Tensor, expert_out: torch.Tensor, - act_fn, - T: int, packed_chunk_size: int, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. @@ -1169,16 +1264,32 @@ def _cumsum_scatter_gather_update_expert_blocked( x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) - gate_prime = expert.gate_proj(x_chunk) - up_prime = expert.up_proj(x_chunk) - down_chunk = expert.down_proj((up_prime * act_fn(gate_prime))) + gate_proj_unpacked = CastToUInt4Func.apply(slot_gate_qweight) + gate_zeros_unpacked = CastToUInt4Func.apply(slot_gate_qzeros) + gate_proj_dq = DequantizeLinearFunc.apply( + gate_proj_unpacked, slot_gate_scales, gate_zeros_unpacked, self.group_size + ) + + up_proj_unpacked = CastToUInt4Func.apply(slot_up_qweight) + up_zeros_unpacked = CastToUInt4Func.apply(slot_up_qzeros) + up_proj_dq = DequantizeLinearFunc.apply( + up_proj_unpacked, slot_up_scales, up_zeros_unpacked, self.group_size + ) - # gate_prime = x_chunk @ W_g - # up_prime = x_chunk @ W_u - # down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + down_proj_unpacked = CastToUInt4Func.apply(slot_down_qweight) + down_zeros_unpacked = CastToUInt4Func.apply(slot_down_qzeros) + + down_proj_dq = DequantizeLinearFunc.apply( + down_proj_unpacked, slot_down_scales, down_zeros_unpacked, self.group_size + ) + + gate_out = torch.bmm(x_chunk, gate_proj_dq) + up_out = torch.bmm(x_chunk, up_proj_dq) + hidden = self.act_fn(gate_out) * up_out + down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2)) rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) - down_chunk = down_chunk * rw_chunk + down_chunk = down_out * rw_chunk expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk @@ -1200,30 +1311,44 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor ) local_experts = len(self.experts) // num_nsp rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - # W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - # W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - # W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + + local_gate_qweight = self.all_gate_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_gate_scales = self.all_gate_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_gate_qzeros = self.all_gate_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + + local_up_qweight = self.all_up_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_up_scales = self.all_up_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_up_qzeros = self.all_up_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + + local_down_qweight = self.all_down_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_down_scales = self.all_down_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_down_qzeros = self.all_down_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + for slot in range(local_experts): routing_weight = rw[:, slot, :] T2Ei = routing_weight > 0 expert_out = self._cumsum_scatter_gather_update_expert_blocked( x=x, T2Ei=T2Ei, - expert=self.experts[slot], - # W_g=W_g[:, slot], - # W_u=W_u[:, slot], - # W_d=W_d[:, slot], + slot_gate_qweight=local_gate_qweight[:, slot], + slot_gate_scales=local_gate_scales[:, slot], + slot_gate_qzeros=local_gate_qzeros[:, slot], + slot_up_qweight=local_up_qweight[:, slot], + slot_up_scales=local_up_scales[:, slot], + slot_up_qzeros=local_up_qzeros[:, slot], + slot_down_qweight=local_down_qweight[:, slot], + slot_down_scales=local_down_scales[:, slot], + slot_down_qzeros=local_down_qzeros[:, slot], routing_weight=routing_weight, expert_out=expert_out, - act_fn=self.experts[0].act_fn, - T=T, packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) return expert_out.sum(dim=0) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - topk_idx, topk_weight, router_probs, router_weights = self.gate(hidden_states) + topk_idx, topk_weight, _, _ = self.gate(hidden_states) B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) @@ -1236,15 +1361,16 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return expert_out.view(B, S, H) final_hidden_states = x.new_zeros((T, H)) - for expert_idx in range(self.n_routed_experts): + for expert_idx in range(self.config.n_routed_experts): expert = self.experts[expert_idx] gate_out = expert.gate_proj(hidden_states) up_out = expert.up_proj(hidden_states) hidden = expert.act_fn(gate_out) * up_out expert_output = expert.down_proj(hidden) current_hidden_states = expert_output * routing_weights[:, expert_idx].unsqueeze(-1) - final_hidden_states += current_hidden_states + final_hidden_states = final_hidden_states.view(B, S, H) + current_hidden_states + final_hidden_states = final_hidden_states + self.shared_experts(hidden_states) return final_hidden_states.view(B, S, H) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index a4b93a4812..3792104a04 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1062,8 +1062,7 @@ class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_string_replace_method = { "DeepseekV3MoE": { "forward": QEffPrefillOnlyDeepseekV3MoE.forward, - # "moe": QEffPrefillOnlyDeepseekV3MoE.moe, - # "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, "_forward_expert_blocked": QEffPrefillOnlyDeepseekV3MoE._forward_expert_blocked, "_cumsum_scatter_gather_update_expert_blocked": QEffPrefillOnlyDeepseekV3MoE._cumsum_scatter_gather_update_expert_blocked, }, diff --git a/examples/kimi_k2/export_kimi_k25.py b/examples/kimi_k2/export_kimi_k25.py new file mode 100644 index 0000000000..5e480abbf9 --- /dev/null +++ b/examples/kimi_k2/export_kimi_k25.py @@ -0,0 +1,238 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import json +import re +import tempfile +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from transformers import AutoConfig, AutoTokenizer +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from QEfficient import QEFFAutoModelForCausalLM + +# parameters to be configured +prompt = "Once upon a time," +num_hidden_layers = 2 +TS = 4 +mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} +# qaic_config = None # Full PKV Cache +# qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + "mla_absorption": mla_absorption, + "enable_blocking": True, + "blocking_mode": "kv", + "num_kv_heads_repeat": TS, +} # for KV blocking with kv head replication +# qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + + +MODEL_PATH = Path( + "/home/huggingface_hub/models--moonshotai--Kimi-K2.5/snapshots/54383e83fa343a1331754112fb9e3410c55efa2f" +) +NUM_HIDDEN_LAYERS = 2 +LOADED_EXPERT_IDS = (0, 1, 2, 3) +NUM_EXPERTS_PER_TOKEN = 2 + +EXPERT_KEY_PATTERN = re.compile(r"^(language_model\.model\.layers\.\d+\.mlp\.experts\.)(\d+)(\..+)$") + + +def _validate_expert_subset(loaded_expert_ids, num_experts_per_tok, total_experts): + expert_ids = tuple(loaded_expert_ids) + if len(expert_ids) != 4: + raise ValueError(f"Expected exactly 4 routed experts, got {expert_ids!r}.") + if len(set(expert_ids)) != len(expert_ids): + raise ValueError(f"Expert ids must be unique, got {expert_ids!r}.") + invalid_ids = [expert_id for expert_id in expert_ids if expert_id < 0 or expert_id >= total_experts] + if invalid_ids: + raise ValueError(f"Expert ids {invalid_ids!r} are outside the valid range [0, {total_experts - 1}].") + if num_experts_per_tok > len(expert_ids): + raise ValueError(f"num_experts_per_tok={num_experts_per_tok} cannot exceed {len(expert_ids)} loaded experts.") + return expert_ids + + +def _remap_checkpoint_key(checkpoint_key, expert_index_map): + match = EXPERT_KEY_PATTERN.match(checkpoint_key) + if not match: + return checkpoint_key + + original_expert_idx = int(match.group(2)) + remapped_expert_idx = expert_index_map.get(original_expert_idx) + if remapped_expert_idx is None: + return None + return f"{match.group(1)}{remapped_expert_idx}{match.group(3)}" + + +def _is_routed_gate_weight(checkpoint_key): + return checkpoint_key.endswith(".mlp.gate.weight") + + +def _is_routed_gate_bias(checkpoint_key): + return checkpoint_key.endswith(".mlp.gate.e_score_correction_bias") + + +def _materialize_subset_checkpoint( + model_path: Path, + temp_model_path: Path, + weight_map, + allowed_prefixes, + loaded_expert_ids, +): + expert_index_map = {expert_id: remapped_idx for remapped_idx, expert_id in enumerate(loaded_expert_ids)} + selected_by_shard = defaultdict(list) + + for checkpoint_key, shard_name in weight_map.items(): + if not any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes): + continue + + remapped_key = _remap_checkpoint_key(checkpoint_key, expert_index_map) + if remapped_key is None: + continue + selected_by_shard[shard_name].append((checkpoint_key, remapped_key)) + + if not selected_by_shard: + raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.") + + filtered_weight_map = {} + subset_shards = [] + for shard_idx, (source_shard_name, shard_entries) in enumerate(sorted(selected_by_shard.items())): + tensors = {} + with safe_open(model_path / source_shard_name, framework="pt", device="cpu") as shard_reader: + for checkpoint_key, remapped_key in shard_entries: + tensor = shard_reader.get_tensor(checkpoint_key) + if _is_routed_gate_weight(checkpoint_key): + tensor = tensor[list(loaded_expert_ids), :].contiguous() + elif _is_routed_gate_bias(checkpoint_key): + tensor = tensor[list(loaded_expert_ids)].contiguous() + tensors[remapped_key] = tensor + + subset_shard_name = f"model-subset-{shard_idx:05d}.safetensors" + save_file(tensors, str(temp_model_path / subset_shard_name)) + subset_shards.append(subset_shard_name) + filtered_weight_map.update({remapped_key: subset_shard_name for _, remapped_key in shard_entries}) + + return filtered_weight_map, subset_shards + + +def load_text_only_kimi( + model_path: Path, + num_hidden_layers: int, + loaded_expert_ids=LOADED_EXPERT_IDS, + num_experts_per_tok: int = NUM_EXPERTS_PER_TOKEN, +): + kimi_config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True) + + # Kimi K2.5 is multimodal, so the text depth must be overridden on text_config. + text_config = copy.deepcopy(kimi_config.text_config) + text_config.num_hidden_layers = num_hidden_layers + loaded_expert_ids = _validate_expert_subset( + loaded_expert_ids, + num_experts_per_tok, + text_config.n_routed_experts, + ) + text_config.n_routed_experts = len(loaded_expert_ids) + text_config.num_experts_per_tok = num_experts_per_tok + text_config.n_group = 1 + text_config.topk_group = 1 + + deepseek_cls = get_class_from_dynamic_module("modeling_deepseek.DeepseekV3ForCausalLM", str(model_path)) + + checkpoint_index = json.loads((model_path / "model.safetensors.index.json").read_text()) + weight_map = checkpoint_index["weight_map"] + + allowed_prefixes = [ + "language_model.model.embed_tokens.", + "language_model.model.norm.", + "language_model.lm_head.", + ] + allowed_prefixes.extend(f"language_model.model.layers.{layer_idx}." for layer_idx in range(num_hidden_layers)) + + with tempfile.TemporaryDirectory() as tmpdir: + temp_model_path = Path(tmpdir) + filtered_weight_map, subset_shards = _materialize_subset_checkpoint( + model_path=model_path, + temp_model_path=temp_model_path, + weight_map=weight_map, + allowed_prefixes=allowed_prefixes, + loaded_expert_ids=loaded_expert_ids, + ) + (temp_model_path / "config.json").write_text(text_config.to_json_string(use_diff=False)) + (temp_model_path / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": { + "total_size": sum((temp_model_path / shard_name).stat().st_size for shard_name in subset_shards) + }, + "weight_map": filtered_weight_map, + } + ) + ) + + # We are loading a task checkpoint into the base text model, so disable the + # base/task prefix heuristic and let `key_mapping` strip `language_model.`. + original_base_model_prefix = deepseek_cls.base_model_prefix + deepseek_cls.base_model_prefix = "" + try: + model, loading_info = deepseek_cls.from_pretrained( + str(temp_model_path), + torch_dtype=torch.float32, + config=text_config, + local_files_only=True, + key_mapping={r"^language_model\.": ""}, + output_loading_info=True, + ) + finally: + deepseek_cls.base_model_prefix = original_base_model_prefix + + unexpected_keys = loading_info["unexpected_keys"] + missing_keys = loading_info["missing_keys"] + mismatched_keys = loading_info["mismatched_keys"] + if unexpected_keys or missing_keys or mismatched_keys: + raise RuntimeError( + "Failed to load the text-only Kimi K2.5 checkpoint slice cleanly. " + f"missing={missing_keys}, unexpected={unexpected_keys}, mismatched={mismatched_keys}" + ) + + model.eval() + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + return model, tokenizer + + +model, tokenizer = load_text_only_kimi( + MODEL_PATH, + NUM_HIDDEN_LAYERS, + loaded_expert_ids=LOADED_EXPERT_IDS, + num_experts_per_tok=NUM_EXPERTS_PER_TOKEN, +) + +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) + +onnx_path = qeff_model.export(prefill_seq_len=1, qaic_config=qaic_config) + +prefill_seq_len = 1 +ctx_len = 1024 + +qpc_path = qeff_model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=TS, + num_cores=16, + qaic_config=qaic_config, +) + +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 03ff29652c7fc509d9f07ffe4d9f0c74f4377ee5 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 May 2026 21:41:54 +0530 Subject: [PATCH 04/17] fixed subfunction compilation Signed-off-by: Onkar Chougule --- .../transformers/models/deepseek_v3/modeling_deepseek.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 2169f2bf0d..5ba3e54773 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1102,10 +1102,10 @@ def moe_waa_unpack(self, hidden_states, topk_indices, topk_weights): hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.in_features_gate) ) - gate_out = torch.bmm(expert_in, gate_proj_dq.transpose(1, 2)) - up_out = torch.bmm(expert_in, up_proj_dq.transpose(1, 2)) + gate_out = torch.bmm(expert_in, gate_proj_dq.transpose(1, 2).to(expert_in.dtype)) + up_out = torch.bmm(expert_in, up_proj_dq.transpose(1, 2).to(expert_in.dtype)) hidden = self.act_fn(gate_out) * up_out - down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2)) + down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2).to(expert_in.dtype)) down_out = down_out.view(-1, self.gate.top_k, self.out_features_down) From 488ac8d66562da31a3e24d433e606a1478e93818 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 8 May 2026 14:29:13 +0530 Subject: [PATCH 05/17] update kv blocking Signed-off-by: Mamta Singh --- QEfficient/blocking/blocked_attention_forwards.py | 6 ++++++ .../models/deepseek_v3/modeling_deepseek.py | 13 +++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index ff35628ee9..6aed6e49f9 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -931,6 +931,12 @@ def blocked_kv_mla_attention_forward( ) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) + if k_heads == 1: + k_pe_block = ( + k_pe_block.unsqueeze(1) + .expand(-1, num_heads, -1, -1, -1) + .reshape(batch_size, num_heads, -1, module.config.qk_rope_head_dim) + ) krope_nope = torch.cat((knope, k_pe_block), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 5ba3e54773..02c47f3d19 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -477,12 +477,7 @@ def fused_forward_orig( k_pe_expanded = k_pe_expanded[:, :q_heads, :, :] else: kva_expanded = kva - num_heads_to_repeat = math.ceil(q_heads / k_heads) - k_pe_expanded = ( - k_pe.unsqueeze(2) - .expand(-1, -1, num_heads_to_repeat, -1, -1) - .reshape(bsz, num_heads_to_repeat * k_heads, -1, self.config.qk_rope_head_dim) - ) + k_pe_expanded = k_pe v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) value_states = torch.matmul(kva_expanded, v_up_per_head) @@ -507,6 +502,12 @@ def fused_forward_orig( self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) ) k_nope = torch.matmul(kva_expanded, k_up_per_head) + if k_heads == 1: + k_pe_expanded = ( + k_pe_expanded.unsqueeze(1) + .expand(-1, self.num_heads, -1, -1, -1) + .reshape(bsz, self.num_heads, -1, self.qk_rope_head_dim) + ) key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale From f2c89d45241efdb3f8d2a0a828796d2c8a41895f Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 8 May 2026 20:33:22 +0530 Subject: [PATCH 06/17] add example script for kimi-k2.5 Signed-off-by: Mamta Singh --- examples/kimi_k2/export_kimi_k25.py | 136 ++-------- .../kimi_k2/export_kimi_k25_4_experts_only.py | 238 ++++++++++++++++++ 2 files changed, 258 insertions(+), 116 deletions(-) create mode 100644 examples/kimi_k2/export_kimi_k25_4_experts_only.py diff --git a/examples/kimi_k2/export_kimi_k25.py b/examples/kimi_k2/export_kimi_k25.py index 5e480abbf9..8b97e94e21 100644 --- a/examples/kimi_k2/export_kimi_k25.py +++ b/examples/kimi_k2/export_kimi_k25.py @@ -7,14 +7,10 @@ import copy import json -import re import tempfile -from collections import defaultdict from pathlib import Path import torch -from safetensors import safe_open -from safetensors.torch import save_file from transformers import AutoConfig, AutoTokenizer from transformers.dynamic_module_utils import get_class_from_dynamic_module @@ -22,7 +18,6 @@ # parameters to be configured prompt = "Once upon a time," -num_hidden_layers = 2 TS = 4 mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} # qaic_config = None # Full PKV Cache @@ -39,114 +34,18 @@ # qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat - MODEL_PATH = Path( "/home/huggingface_hub/models--moonshotai--Kimi-K2.5/snapshots/54383e83fa343a1331754112fb9e3410c55efa2f" ) NUM_HIDDEN_LAYERS = 2 -LOADED_EXPERT_IDS = (0, 1, 2, 3) -NUM_EXPERTS_PER_TOKEN = 2 - -EXPERT_KEY_PATTERN = re.compile(r"^(language_model\.model\.layers\.\d+\.mlp\.experts\.)(\d+)(\..+)$") - - -def _validate_expert_subset(loaded_expert_ids, num_experts_per_tok, total_experts): - expert_ids = tuple(loaded_expert_ids) - if len(expert_ids) != 4: - raise ValueError(f"Expected exactly 4 routed experts, got {expert_ids!r}.") - if len(set(expert_ids)) != len(expert_ids): - raise ValueError(f"Expert ids must be unique, got {expert_ids!r}.") - invalid_ids = [expert_id for expert_id in expert_ids if expert_id < 0 or expert_id >= total_experts] - if invalid_ids: - raise ValueError(f"Expert ids {invalid_ids!r} are outside the valid range [0, {total_experts - 1}].") - if num_experts_per_tok > len(expert_ids): - raise ValueError(f"num_experts_per_tok={num_experts_per_tok} cannot exceed {len(expert_ids)} loaded experts.") - return expert_ids - - -def _remap_checkpoint_key(checkpoint_key, expert_index_map): - match = EXPERT_KEY_PATTERN.match(checkpoint_key) - if not match: - return checkpoint_key - - original_expert_idx = int(match.group(2)) - remapped_expert_idx = expert_index_map.get(original_expert_idx) - if remapped_expert_idx is None: - return None - return f"{match.group(1)}{remapped_expert_idx}{match.group(3)}" - - -def _is_routed_gate_weight(checkpoint_key): - return checkpoint_key.endswith(".mlp.gate.weight") - - -def _is_routed_gate_bias(checkpoint_key): - return checkpoint_key.endswith(".mlp.gate.e_score_correction_bias") -def _materialize_subset_checkpoint( - model_path: Path, - temp_model_path: Path, - weight_map, - allowed_prefixes, - loaded_expert_ids, -): - expert_index_map = {expert_id: remapped_idx for remapped_idx, expert_id in enumerate(loaded_expert_ids)} - selected_by_shard = defaultdict(list) - - for checkpoint_key, shard_name in weight_map.items(): - if not any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes): - continue - - remapped_key = _remap_checkpoint_key(checkpoint_key, expert_index_map) - if remapped_key is None: - continue - selected_by_shard[shard_name].append((checkpoint_key, remapped_key)) - - if not selected_by_shard: - raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.") - - filtered_weight_map = {} - subset_shards = [] - for shard_idx, (source_shard_name, shard_entries) in enumerate(sorted(selected_by_shard.items())): - tensors = {} - with safe_open(model_path / source_shard_name, framework="pt", device="cpu") as shard_reader: - for checkpoint_key, remapped_key in shard_entries: - tensor = shard_reader.get_tensor(checkpoint_key) - if _is_routed_gate_weight(checkpoint_key): - tensor = tensor[list(loaded_expert_ids), :].contiguous() - elif _is_routed_gate_bias(checkpoint_key): - tensor = tensor[list(loaded_expert_ids)].contiguous() - tensors[remapped_key] = tensor - - subset_shard_name = f"model-subset-{shard_idx:05d}.safetensors" - save_file(tensors, str(temp_model_path / subset_shard_name)) - subset_shards.append(subset_shard_name) - filtered_weight_map.update({remapped_key: subset_shard_name for _, remapped_key in shard_entries}) - - return filtered_weight_map, subset_shards - - -def load_text_only_kimi( - model_path: Path, - num_hidden_layers: int, - loaded_expert_ids=LOADED_EXPERT_IDS, - num_experts_per_tok: int = NUM_EXPERTS_PER_TOKEN, -): +def load_text_only_kimi(model_path: Path, num_hidden_layers: int): kimi_config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True) # Kimi K2.5 is multimodal, so the text depth must be overridden on text_config. text_config = copy.deepcopy(kimi_config.text_config) text_config.num_hidden_layers = num_hidden_layers - loaded_expert_ids = _validate_expert_subset( - loaded_expert_ids, - num_experts_per_tok, - text_config.n_routed_experts, - ) - text_config.n_routed_experts = len(loaded_expert_ids) - text_config.num_experts_per_tok = num_experts_per_tok - text_config.n_group = 1 - text_config.topk_group = 1 deepseek_cls = get_class_from_dynamic_module("modeling_deepseek.DeepseekV3ForCausalLM", str(model_path)) @@ -160,26 +59,36 @@ def load_text_only_kimi( ] allowed_prefixes.extend(f"language_model.model.layers.{layer_idx}." for layer_idx in range(num_hidden_layers)) + required_shards = sorted( + { + shard_name + for checkpoint_key, shard_name in weight_map.items() + if any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes) + } + ) + filtered_weight_map = { + checkpoint_key: shard_name + for checkpoint_key, shard_name in weight_map.items() + if any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes) + } + if not filtered_weight_map: + raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.") + with tempfile.TemporaryDirectory() as tmpdir: temp_model_path = Path(tmpdir) - filtered_weight_map, subset_shards = _materialize_subset_checkpoint( - model_path=model_path, - temp_model_path=temp_model_path, - weight_map=weight_map, - allowed_prefixes=allowed_prefixes, - loaded_expert_ids=loaded_expert_ids, - ) (temp_model_path / "config.json").write_text(text_config.to_json_string(use_diff=False)) (temp_model_path / "model.safetensors.index.json").write_text( json.dumps( { "metadata": { - "total_size": sum((temp_model_path / shard_name).stat().st_size for shard_name in subset_shards) + "total_size": sum((model_path / shard_name).stat().st_size for shard_name in required_shards) }, "weight_map": filtered_weight_map, } ) ) + for shard_name in required_shards: + (temp_model_path / shard_name).symlink_to(model_path / shard_name) # We are loading a task checkpoint into the base text model, so disable the # base/task prefix heuristic and let `key_mapping` strip `language_model.`. @@ -211,12 +120,7 @@ def load_text_only_kimi( return model, tokenizer -model, tokenizer = load_text_only_kimi( - MODEL_PATH, - NUM_HIDDEN_LAYERS, - loaded_expert_ids=LOADED_EXPERT_IDS, - num_experts_per_tok=NUM_EXPERTS_PER_TOKEN, -) +model, tokenizer = load_text_only_kimi(MODEL_PATH, NUM_HIDDEN_LAYERS) qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) diff --git a/examples/kimi_k2/export_kimi_k25_4_experts_only.py b/examples/kimi_k2/export_kimi_k25_4_experts_only.py new file mode 100644 index 0000000000..5e480abbf9 --- /dev/null +++ b/examples/kimi_k2/export_kimi_k25_4_experts_only.py @@ -0,0 +1,238 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import json +import re +import tempfile +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from transformers import AutoConfig, AutoTokenizer +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from QEfficient import QEFFAutoModelForCausalLM + +# parameters to be configured +prompt = "Once upon a time," +num_hidden_layers = 2 +TS = 4 +mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} +# qaic_config = None # Full PKV Cache +# qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + "mla_absorption": mla_absorption, + "enable_blocking": True, + "blocking_mode": "kv", + "num_kv_heads_repeat": TS, +} # for KV blocking with kv head replication +# qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + + +MODEL_PATH = Path( + "/home/huggingface_hub/models--moonshotai--Kimi-K2.5/snapshots/54383e83fa343a1331754112fb9e3410c55efa2f" +) +NUM_HIDDEN_LAYERS = 2 +LOADED_EXPERT_IDS = (0, 1, 2, 3) +NUM_EXPERTS_PER_TOKEN = 2 + +EXPERT_KEY_PATTERN = re.compile(r"^(language_model\.model\.layers\.\d+\.mlp\.experts\.)(\d+)(\..+)$") + + +def _validate_expert_subset(loaded_expert_ids, num_experts_per_tok, total_experts): + expert_ids = tuple(loaded_expert_ids) + if len(expert_ids) != 4: + raise ValueError(f"Expected exactly 4 routed experts, got {expert_ids!r}.") + if len(set(expert_ids)) != len(expert_ids): + raise ValueError(f"Expert ids must be unique, got {expert_ids!r}.") + invalid_ids = [expert_id for expert_id in expert_ids if expert_id < 0 or expert_id >= total_experts] + if invalid_ids: + raise ValueError(f"Expert ids {invalid_ids!r} are outside the valid range [0, {total_experts - 1}].") + if num_experts_per_tok > len(expert_ids): + raise ValueError(f"num_experts_per_tok={num_experts_per_tok} cannot exceed {len(expert_ids)} loaded experts.") + return expert_ids + + +def _remap_checkpoint_key(checkpoint_key, expert_index_map): + match = EXPERT_KEY_PATTERN.match(checkpoint_key) + if not match: + return checkpoint_key + + original_expert_idx = int(match.group(2)) + remapped_expert_idx = expert_index_map.get(original_expert_idx) + if remapped_expert_idx is None: + return None + return f"{match.group(1)}{remapped_expert_idx}{match.group(3)}" + + +def _is_routed_gate_weight(checkpoint_key): + return checkpoint_key.endswith(".mlp.gate.weight") + + +def _is_routed_gate_bias(checkpoint_key): + return checkpoint_key.endswith(".mlp.gate.e_score_correction_bias") + + +def _materialize_subset_checkpoint( + model_path: Path, + temp_model_path: Path, + weight_map, + allowed_prefixes, + loaded_expert_ids, +): + expert_index_map = {expert_id: remapped_idx for remapped_idx, expert_id in enumerate(loaded_expert_ids)} + selected_by_shard = defaultdict(list) + + for checkpoint_key, shard_name in weight_map.items(): + if not any(checkpoint_key.startswith(prefix) for prefix in allowed_prefixes): + continue + + remapped_key = _remap_checkpoint_key(checkpoint_key, expert_index_map) + if remapped_key is None: + continue + selected_by_shard[shard_name].append((checkpoint_key, remapped_key)) + + if not selected_by_shard: + raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.") + + filtered_weight_map = {} + subset_shards = [] + for shard_idx, (source_shard_name, shard_entries) in enumerate(sorted(selected_by_shard.items())): + tensors = {} + with safe_open(model_path / source_shard_name, framework="pt", device="cpu") as shard_reader: + for checkpoint_key, remapped_key in shard_entries: + tensor = shard_reader.get_tensor(checkpoint_key) + if _is_routed_gate_weight(checkpoint_key): + tensor = tensor[list(loaded_expert_ids), :].contiguous() + elif _is_routed_gate_bias(checkpoint_key): + tensor = tensor[list(loaded_expert_ids)].contiguous() + tensors[remapped_key] = tensor + + subset_shard_name = f"model-subset-{shard_idx:05d}.safetensors" + save_file(tensors, str(temp_model_path / subset_shard_name)) + subset_shards.append(subset_shard_name) + filtered_weight_map.update({remapped_key: subset_shard_name for _, remapped_key in shard_entries}) + + return filtered_weight_map, subset_shards + + +def load_text_only_kimi( + model_path: Path, + num_hidden_layers: int, + loaded_expert_ids=LOADED_EXPERT_IDS, + num_experts_per_tok: int = NUM_EXPERTS_PER_TOKEN, +): + kimi_config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True) + + # Kimi K2.5 is multimodal, so the text depth must be overridden on text_config. + text_config = copy.deepcopy(kimi_config.text_config) + text_config.num_hidden_layers = num_hidden_layers + loaded_expert_ids = _validate_expert_subset( + loaded_expert_ids, + num_experts_per_tok, + text_config.n_routed_experts, + ) + text_config.n_routed_experts = len(loaded_expert_ids) + text_config.num_experts_per_tok = num_experts_per_tok + text_config.n_group = 1 + text_config.topk_group = 1 + + deepseek_cls = get_class_from_dynamic_module("modeling_deepseek.DeepseekV3ForCausalLM", str(model_path)) + + checkpoint_index = json.loads((model_path / "model.safetensors.index.json").read_text()) + weight_map = checkpoint_index["weight_map"] + + allowed_prefixes = [ + "language_model.model.embed_tokens.", + "language_model.model.norm.", + "language_model.lm_head.", + ] + allowed_prefixes.extend(f"language_model.model.layers.{layer_idx}." for layer_idx in range(num_hidden_layers)) + + with tempfile.TemporaryDirectory() as tmpdir: + temp_model_path = Path(tmpdir) + filtered_weight_map, subset_shards = _materialize_subset_checkpoint( + model_path=model_path, + temp_model_path=temp_model_path, + weight_map=weight_map, + allowed_prefixes=allowed_prefixes, + loaded_expert_ids=loaded_expert_ids, + ) + (temp_model_path / "config.json").write_text(text_config.to_json_string(use_diff=False)) + (temp_model_path / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": { + "total_size": sum((temp_model_path / shard_name).stat().st_size for shard_name in subset_shards) + }, + "weight_map": filtered_weight_map, + } + ) + ) + + # We are loading a task checkpoint into the base text model, so disable the + # base/task prefix heuristic and let `key_mapping` strip `language_model.`. + original_base_model_prefix = deepseek_cls.base_model_prefix + deepseek_cls.base_model_prefix = "" + try: + model, loading_info = deepseek_cls.from_pretrained( + str(temp_model_path), + torch_dtype=torch.float32, + config=text_config, + local_files_only=True, + key_mapping={r"^language_model\.": ""}, + output_loading_info=True, + ) + finally: + deepseek_cls.base_model_prefix = original_base_model_prefix + + unexpected_keys = loading_info["unexpected_keys"] + missing_keys = loading_info["missing_keys"] + mismatched_keys = loading_info["mismatched_keys"] + if unexpected_keys or missing_keys or mismatched_keys: + raise RuntimeError( + "Failed to load the text-only Kimi K2.5 checkpoint slice cleanly. " + f"missing={missing_keys}, unexpected={unexpected_keys}, mismatched={mismatched_keys}" + ) + + model.eval() + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + return model, tokenizer + + +model, tokenizer = load_text_only_kimi( + MODEL_PATH, + NUM_HIDDEN_LAYERS, + loaded_expert_ids=LOADED_EXPERT_IDS, + num_experts_per_tok=NUM_EXPERTS_PER_TOKEN, +) + +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) + +onnx_path = qeff_model.export(prefill_seq_len=1, qaic_config=qaic_config) + +prefill_seq_len = 1 +ctx_len = 1024 + +qpc_path = qeff_model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=TS, + num_cores=16, + qaic_config=qaic_config, +) + +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From cf43bc9f9c2a13ab8107b2296da7fc7793e09ee4 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Sun, 10 May 2026 12:54:42 +0530 Subject: [PATCH 07/17] remove redundant export calls Signed-off-by: Mamta Singh --- examples/kimi_k2/export_kimi_k25.py | 2 -- examples/kimi_k2/export_kimi_k25_4_experts_only.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/examples/kimi_k2/export_kimi_k25.py b/examples/kimi_k2/export_kimi_k25.py index 8b97e94e21..379c9645f8 100644 --- a/examples/kimi_k2/export_kimi_k25.py +++ b/examples/kimi_k2/export_kimi_k25.py @@ -124,8 +124,6 @@ def load_text_only_kimi(model_path: Path, num_hidden_layers: int): qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) -onnx_path = qeff_model.export(prefill_seq_len=1, qaic_config=qaic_config) - prefill_seq_len = 1 ctx_len = 1024 diff --git a/examples/kimi_k2/export_kimi_k25_4_experts_only.py b/examples/kimi_k2/export_kimi_k25_4_experts_only.py index 5e480abbf9..26f964ba36 100644 --- a/examples/kimi_k2/export_kimi_k25_4_experts_only.py +++ b/examples/kimi_k2/export_kimi_k25_4_experts_only.py @@ -220,8 +220,6 @@ def load_text_only_kimi( qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) -onnx_path = qeff_model.export(prefill_seq_len=1, qaic_config=qaic_config) - prefill_seq_len = 1 ctx_len = 1024 From 5aa8bca691e8ed801fead06f059222d3c0af8a48 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 13 May 2026 18:09:41 +0530 Subject: [PATCH 08/17] fix prefill output Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 02c47f3d19..98b032a9f2 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1284,8 +1284,8 @@ def _cumsum_scatter_gather_update_expert_blocked( down_proj_unpacked, slot_down_scales, down_zeros_unpacked, self.group_size ) - gate_out = torch.bmm(x_chunk, gate_proj_dq) - up_out = torch.bmm(x_chunk, up_proj_dq) + gate_out = torch.bmm(x_chunk, gate_proj_dq.transpose(1, 2)) + up_out = torch.bmm(x_chunk, up_proj_dq.transpose(1, 2)) hidden = self.act_fn(gate_out) * up_out down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2)) @@ -1315,17 +1315,17 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor expert_out = x.new_zeros((num_nsp, T, H)) - local_gate_qweight = self.all_gate_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_gate_scales = self.all_gate_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_gate_qzeros = self.all_gate_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_gate_qweight = self.all_gate_qweight.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // 2).transpose(0, 1).contiguous() + local_gate_scales = self.all_gate_scales.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // self.group_size).transpose(0, 1).contiguous() + local_gate_qzeros = self.all_gate_qzeros.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // (self.group_size * 2)).transpose(0, 1).contiguous() - local_up_qweight = self.all_up_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_up_scales = self.all_up_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_up_qzeros = self.all_up_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_up_qweight = self.all_up_qweight.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // 2).transpose(0, 1).contiguous() + local_up_scales = self.all_up_scales.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // self.group_size).transpose(0, 1).contiguous() + local_up_qzeros = self.all_up_qzeros.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // (self.group_size * 2)).transpose(0, 1).contiguous() - local_down_qweight = self.all_down_qweight.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_down_scales = self.all_down_scales.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - local_down_qzeros = self.all_down_qzeros.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + local_down_qweight = self.all_down_qweight.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // 2).transpose(0, 1).contiguous() + local_down_scales = self.all_down_scales.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // self.group_size).transpose(0, 1).contiguous() + local_down_qzeros = self.all_down_qzeros.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // (self.group_size * 2)).transpose(0, 1).contiguous() for slot in range(local_experts): routing_weight = rw[:, slot, :] @@ -1359,7 +1359,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens if len(self.experts) % EXPERT_BLOCKING_NUM_NSP == 0: expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - return expert_out.view(B, S, H) + return expert_out.view(B, S, H) + self.shared_experts(hidden_states) final_hidden_states = x.new_zeros((T, H)) for expert_idx in range(self.config.n_routed_experts): From b353b2f77211f21b019f298e4cc4c854c42d3aee Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 13 May 2026 19:29:19 +0530 Subject: [PATCH 09/17] fixed EP Q chunking Signed-off-by: Onkar Chougule --- .../models/deepseek_v3/modeling_deepseek.py | 112 +++++++++++++++--- .../transformers/models/modeling_auto.py | 6 + QEfficient/utils/constants.py | 2 +- examples/kimi_k2/export_kimi_k25.py | 7 +- .../kimi_k2/export_kimi_k25_4_experts_only.py | 5 +- 5 files changed, 108 insertions(+), 24 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 98b032a9f2..7d88267216 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -763,7 +763,7 @@ def forward( ) -EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "4")) EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) @@ -1235,6 +1235,7 @@ def _cumsum_scatter_gather_update_expert_blocked( routing_weight: torch.Tensor, expert_out: torch.Tensor, packed_chunk_size: int, + num_q_ffn_blocks: Optional[int] = None, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. @@ -1253,14 +1254,26 @@ def _cumsum_scatter_gather_update_expert_blocked( batch_size, seq_len = T2Ei.shape packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + if num_q_ffn_blocks is not None: + assert seq_len % num_q_ffn_blocks == 0, "Something went wrong" + packed_chunk_size = seq_len // num_q_ffn_blocks + else: + num_q_ffn_blocks = seq_len // packed_chunk_size + matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) - for packed_start in range(0, seq_len, packed_chunk_size): - packed_stop = packed_start + packed_chunk_size + for chunk_idx in range(num_q_ffn_blocks): + print("executing chunk", chunk_idx) + packed_start = chunk_idx * packed_chunk_size + if chunk_idx == num_q_ffn_blocks - 1: + packed_stop = seq_len + else: + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) @@ -1303,7 +1316,9 @@ def _cumsum_scatter_gather_update_expert_blocked( return expert_out - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_q_ffn_blocks: Optional[int] = None + ) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if len(self.experts) % num_nsp != 0: @@ -1315,19 +1330,68 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor expert_out = x.new_zeros((num_nsp, T, H)) - local_gate_qweight = self.all_gate_qweight.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // 2).transpose(0, 1).contiguous() - local_gate_scales = self.all_gate_scales.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // self.group_size).transpose(0, 1).contiguous() - local_gate_qzeros = self.all_gate_qzeros.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // (self.group_size * 2)).transpose(0, 1).contiguous() + local_gate_qweight = ( + self.all_gate_qweight.view(local_experts, num_nsp, self.out_features_gate, self.in_features_gate // 2) + .transpose(0, 1) + .contiguous() + ) + local_gate_scales = ( + self.all_gate_scales.view( + local_experts, num_nsp, self.out_features_gate, self.in_features_gate // self.group_size + ) + .transpose(0, 1) + .contiguous() + ) + local_gate_qzeros = ( + self.all_gate_qzeros.view( + local_experts, num_nsp, self.out_features_gate, self.in_features_gate // (self.group_size * 2) + ) + .transpose(0, 1) + .contiguous() + ) - local_up_qweight = self.all_up_qweight.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // 2).transpose(0, 1).contiguous() - local_up_scales = self.all_up_scales.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // self.group_size).transpose(0, 1).contiguous() - local_up_qzeros = self.all_up_qzeros.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // (self.group_size * 2)).transpose(0, 1).contiguous() + local_up_qweight = ( + self.all_up_qweight.view(local_experts, num_nsp, self.out_features_up, self.in_features_up // 2) + .transpose(0, 1) + .contiguous() + ) + local_up_scales = ( + self.all_up_scales.view( + local_experts, num_nsp, self.out_features_up, self.in_features_up // self.group_size + ) + .transpose(0, 1) + .contiguous() + ) + local_up_qzeros = ( + self.all_up_qzeros.view( + local_experts, num_nsp, self.out_features_up, self.in_features_up // (self.group_size * 2) + ) + .transpose(0, 1) + .contiguous() + ) - local_down_qweight = self.all_down_qweight.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // 2).transpose(0, 1).contiguous() - local_down_scales = self.all_down_scales.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // self.group_size).transpose(0, 1).contiguous() - local_down_qzeros = self.all_down_qzeros.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // (self.group_size * 2)).transpose(0, 1).contiguous() + local_down_qweight = ( + self.all_down_qweight.view(local_experts, num_nsp, self.out_features_down, self.in_features_down // 2) + .transpose(0, 1) + .contiguous() + ) + local_down_scales = ( + self.all_down_scales.view( + local_experts, num_nsp, self.out_features_down, self.in_features_down // self.group_size + ) + .transpose(0, 1) + .contiguous() + ) + local_down_qzeros = ( + self.all_down_qzeros.view( + local_experts, num_nsp, self.out_features_down, self.in_features_down // (self.group_size * 2) + ) + .transpose(0, 1) + .contiguous() + ) for slot in range(local_experts): + print(f"executing slot {slot}") routing_weight = rw[:, slot, :] T2Ei = routing_weight > 0 expert_out = self._cumsum_scatter_gather_update_expert_blocked( @@ -1345,10 +1409,13 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor routing_weight=routing_weight, expert_out=expert_out, packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + num_q_ffn_blocks=num_q_ffn_blocks, ) return expert_out.sum(dim=0) - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, hidden_states: torch.Tensor, num_q_ffn_blocks: Optional[int] = None + ) -> tuple[torch.Tensor, torch.Tensor]: topk_idx, topk_weight, _, _ = self.gate(hidden_states) B, S, H = hidden_states.shape T = B * S @@ -1358,8 +1425,10 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights.scatter_(1, topk_idx, topk_weight) if len(self.experts) % EXPERT_BLOCKING_NUM_NSP == 0: - expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - return expert_out.view(B, S, H) + self.shared_experts(hidden_states) + expert_out = self._forward_expert_blocked( + x=x, routing_weights=routing_weights, num_q_ffn_blocks=num_q_ffn_blocks + ) + self.shared_experts(hidden_states) + return expert_out.view(B, S, H) final_hidden_states = x.new_zeros((T, H)) for expert_idx in range(self.config.n_routed_experts): @@ -1391,10 +1460,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, mla_absorption: Optional[Dict[str, bool]] = None, + num_q_ffn_blocks: Optional[int] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states orig_hidden_states = self.input_layernorm(hidden_states) + # setattr(self.mlp, "num_q_ffn_blocks", num_q_ffn_blocks) if mla_absorption is not None: cache_compressed = mla_absorption.get("cache_compressed", False) else: @@ -1431,7 +1502,11 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + + if num_q_ffn_blocks is not None and self.mlp.__class__.__name__ == "DeepseekV3MoE": + self.mlp(hidden_states, num_q_ffn_blocks) + else: + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -1530,7 +1605,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + num_q_ffn_blocks = getattr(self, "num_q_blocks_ffn", None) for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1547,6 +1622,7 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, mla_absorption=mla_absorption, + num_q_ffn_blocks=num_q_ffn_blocks, **kwargs, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 50bb99e40f..f1b8b6e3cc 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3236,6 +3236,12 @@ def export( qaic_config=self.model.qaic_config, ) + if prefill_only: + assert prefill_seq_len is not None, "prefill_seq_len must be provided when prefill_only is True" + num_q_blocks_ffn = prefill_seq_len // constants.EXPERT_BLOCKING_PACKED_CHUNK_SIZE + num_q_blocks_ffn = num_q_blocks_ffn if num_q_blocks_ffn > 0 else 1 + setattr(self.model.model, "num_q_blocks_ffn", num_q_blocks_ffn) + return self._export( example_inputs, output_names=output_names, diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..e7222787c5 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -103,7 +103,7 @@ def get_models_dir(): COMPILER = ["/opt/qti-aic/exec/qaic-compile", "-aic-hw"] DEFAULT_AIC_HW_VERSION = "ai100" ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 - +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B INTERN_FEATURE_SIZE = 256 diff --git a/examples/kimi_k2/export_kimi_k25.py b/examples/kimi_k2/export_kimi_k25.py index 379c9645f8..a836f5245a 100644 --- a/examples/kimi_k2/export_kimi_k25.py +++ b/examples/kimi_k2/export_kimi_k25.py @@ -18,7 +18,7 @@ # parameters to be configured prompt = "Once upon a time," -TS = 4 +TS = 1 mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking @@ -29,7 +29,7 @@ "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", - "num_kv_heads_repeat": TS, + "num_kv_heads_repeat": 1, } # for KV blocking with kv head replication # qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat @@ -124,7 +124,7 @@ def load_text_only_kimi(model_path: Path, num_hidden_layers: int): qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) -prefill_seq_len = 1 +prefill_seq_len = 32 ctx_len = 1024 qpc_path = qeff_model.compile( @@ -135,6 +135,7 @@ def load_text_only_kimi(model_path: Path, num_hidden_layers: int): num_devices=TS, num_cores=16, qaic_config=qaic_config, + prefill_only=True ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/kimi_k2/export_kimi_k25_4_experts_only.py b/examples/kimi_k2/export_kimi_k25_4_experts_only.py index 26f964ba36..384d0f9da0 100644 --- a/examples/kimi_k2/export_kimi_k25_4_experts_only.py +++ b/examples/kimi_k2/export_kimi_k25_4_experts_only.py @@ -23,7 +23,7 @@ # parameters to be configured prompt = "Once upon a time," num_hidden_layers = 2 -TS = 4 +TS = 1 mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking @@ -220,7 +220,7 @@ def load_text_only_kimi( qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) -prefill_seq_len = 1 +prefill_seq_len = 32 ctx_len = 1024 qpc_path = qeff_model.compile( @@ -231,6 +231,7 @@ def load_text_only_kimi( num_devices=TS, num_cores=16, qaic_config=qaic_config, + prefill_only=True ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From b80eb76e41dd64ca5d5e1b695f6de2124ea0b455 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 13 May 2026 20:11:32 +0530 Subject: [PATCH 10/17] fixed tracer Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 7d88267216..f2d7c29fd4 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1504,7 +1504,7 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) if num_q_ffn_blocks is not None and self.mlp.__class__.__name__ == "DeepseekV3MoE": - self.mlp(hidden_states, num_q_ffn_blocks) + hidden_states = self.mlp(hidden_states, num_q_ffn_blocks) else: hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states From 8bdab23b4af41788f18c8c9a839e6b53f81989f6 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 13 May 2026 20:45:01 +0530 Subject: [PATCH 11/17] fixed ctxgather Signed-off-by: Onkar Chougule --- QEfficient/customop/ctx_scatter_gather.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 19f60886de..b1f322c606 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -166,7 +166,10 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) + batch_size = ops.Slice(ops.Shape(data), starts=[0], ends=[1], axes=[0]) + idx_seq_len = ops.Slice(ops.Shape(ctx_indices), starts=[1], ends=[2], axes=[0]) + expand_shape = ops.Concat(batch_size, idx_seq_len, axis=0) + ctx_indices = ops.Expand(ctx_indices, expand_shape) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=1) From 5f1e1f6ffe0f4f686ec82eb73580aa75407f514f Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Wed, 13 May 2026 22:55:53 +0530 Subject: [PATCH 12/17] add onnx transforms Signed-off-by: Mamta Singh --- QEfficient/base/onnx_transforms.py | 7 +++++++ examples/kimi_k2/export_kimi_k25.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 32bb0a8fa4..b3db86a55a 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -26,6 +26,10 @@ CtxScatter3D, CtxScatterFunc, CtxScatterFunc3D, + CtxScatter3DInt, + CtxScatterFunc3DInt, + CtxScatterFunc3DGeneralized, + CtxGatherFunc3DGeneralized, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherBlockedKVCB, @@ -102,6 +106,9 @@ class CustomOpTransform(BaseOnnxTransform): "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), "CastToUInt4": (CastToUInt4Func, CastToUInt4), + "CtxScatterFunc3DInt": (CtxScatterFunc3DInt, CtxScatter3DInt), + "CtxScatterFunc3DGeneralized":(CtxScatterFunc3DGeneralized, CtxScatter3D), + "CtxGatherFunc3DGeneralized": (CtxGatherFunc3DGeneralized, CtxGather3D), } @classmethod diff --git a/examples/kimi_k2/export_kimi_k25.py b/examples/kimi_k2/export_kimi_k25.py index a836f5245a..3486107731 100644 --- a/examples/kimi_k2/export_kimi_k25.py +++ b/examples/kimi_k2/export_kimi_k25.py @@ -135,7 +135,8 @@ def load_text_only_kimi(model_path: Path, num_hidden_layers: int): num_devices=TS, num_cores=16, qaic_config=qaic_config, - prefill_only=True + prefill_only=True, + use_onnx_subfunctions=True, ) qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) From 0dba1255b827086557e39b570b85f0c09a155969 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Thu, 14 May 2026 15:44:01 +0530 Subject: [PATCH 13/17] fix subfunc compilation Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index f2d7c29fd4..c18993f692 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -802,9 +802,7 @@ def forward(self, hidden_states): if self.topk_method == "noaux_tc": assert not self.training scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] + group_scores = torch.einsum("abc->ab", scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0]) # [n, n_group] group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] @@ -821,7 +819,8 @@ def forward(self, hidden_states): ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + denominator = torch.einsum("bi->b", topk_weight).unsqueeze(-1) + 1e-20 + topk_weight = topk_weight / denominator topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor @@ -1261,7 +1260,8 @@ def _cumsum_scatter_gather_update_expert_blocked( num_q_ffn_blocks = seq_len // packed_chunk_size matched_idx = _build_matched_idx_from_cumsum(T2Ei) - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + valid_rows = torch.einsum("bi->b", T2Ei.to(torch.int32)).unsqueeze(1) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) @@ -1308,7 +1308,9 @@ def _cumsum_scatter_gather_update_expert_blocked( expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + x = valid_rows - packed_start + x = torch.where(x < 0, torch.zeros_like(x), x) + chunk_valid_rows = torch.where(x > packed_chunk_size, packed_chunk_size, x) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) @@ -1411,7 +1413,7 @@ def _forward_expert_blocked( packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, num_q_ffn_blocks=num_q_ffn_blocks, ) - return expert_out.sum(dim=0) + return torch.einsum("bij->ij", expert_out) def forward( self, hidden_states: torch.Tensor, num_q_ffn_blocks: Optional[int] = None From 9c299407049526835155e11229a217f76a48c07b Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Fri, 15 May 2026 12:17:43 +0530 Subject: [PATCH 14/17] add kv_par_mla_attention_forward Signed-off-by: Mamta Singh --- QEfficient/blocking/attention_blocking.py | 5 + .../blocking/blocked_attention_forwards.py | 180 ++++++++++++++++++ QEfficient/blocking/blocking_configurator.py | 6 +- .../models/deepseek_v3/modeling_deepseek.py | 93 +++++++++ .../transformers/models/pytorch_transforms.py | 1 + .../kimi_k2/export_kimi_k25_4_experts_only.py | 2 + 6 files changed, 286 insertions(+), 1 deletion(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index 2ab5c03bec..e5cd8aade3 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -21,6 +21,7 @@ blocked_hqkv_attention_forward, blocked_kv_attention_forward, blocked_kv_mla_attention_forward, + blocked_kv_par_mla_attention_forward, blocked_q_attention_forward, blocked_qkv_attention_forward, ) @@ -34,6 +35,7 @@ class BlockingMode(str, Enum): QKV = "qkv" HQKV = "hqkv" BHQKV = "bhqkv" + PAR = "par" @dataclass @@ -44,6 +46,7 @@ class AttentionBlockingConfig: head_block_size: Optional[int] = None skip_kv: Optional[bool] = False num_batch_blocks: Optional[int] = None + par_num_split: Optional[int] = None def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: @@ -62,6 +65,7 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: _STRATEGIES_MLA: Dict[BlockingMode, Callable] = { BlockingMode.KV: blocked_kv_mla_attention_forward, BlockingMode.H: blocked_h_mla_attention_forward, + BlockingMode.PAR: blocked_kv_par_mla_attention_forward } @@ -224,6 +228,7 @@ def generic_blocked_mla_attention_interface( num_q_blocks=blocking_config.num_q_blocks, head_block_size=blocking_config.head_block_size, num_batch_blocks=blocking_config.num_batch_blocks, + par_num_split=blocking_config.par_num_split, score_mod=score_mod, position_bias=position_bias, sinks=sinks, diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 6aed6e49f9..2110afb98c 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -957,6 +957,186 @@ def blocked_kv_mla_attention_forward( return attn_output, attn_weights +def blocked_kv_par_mla_attention_forward( + module: nn.Module, + query: torch.Tensor, # [B, NQH, QL, D_abs] absorption-space Q + per_head_v_up: torch.Tensor, # [1, NQH, kv_lora_rank, v_head_dim] + per_head_k_up_normal: torch.Tensor, # [1, NQH, qk_nope_head_dim, kv_lora_rank] — for non-absorption K + mla_absorption: Dict[str, Any], + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + par_num_split: int, # T-dim split within each KV block (maps to NSP cores) + cache_kwargs: Dict[str, Any], + layer_idx: int, + compressed_kvs, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + GQA headpar-style MLA attention. + + Layout matches qwen3_gqa_kv_blocking_microbench.py: + q_fold = query.reshape(B, Hkv, QL*n_rep, D) # simple reshape, no permute + Q_5d = q_fold.unsqueeze(2).expand(..., split, ...) # broadcast over split + K_5d = k_block.view(B, Hkv, split, T_h, D) # consecutive T split + + Merge is two-stage offline (buffer all blocks): + Stage 1: max/exp/sum across KV blocks + Stage 2: max/exp/sum across splits + """ + B, NQH, QL, D_abs = query.shape + kv_lora_rank = module.config.kv_lora_rank + split = par_num_split + + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + else: + absorption = False + + # absorption=True : all n_rep heads in a group share the same K (= ckv||k_pe) + # ? fold: Hkv = module.num_key_value_heads, n_rep = NQH // Hkv + # absorption=False: each query head has its own K = ckv @ k_up_h + # ? cannot fold across heads; treat as Hkv=NQH, n_rep=1 + if absorption: + Hkv = getattr(module, "num_key_value_heads", 1) + n_rep = NQH // Hkv + else: + Hkv = NQH + n_rep = 1 + + # -- Q fold: reshape + unsqueeze + expand (GQA style) --------------------- + q_fold = query.reshape(B, Hkv, QL * n_rep, D_abs) + Q_5d = q_fold.unsqueeze(2).expand(B, Hkv, split, QL * n_rep, D_abs) + + ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2] + kv_block_size = -(-ctx_len // num_kv_blocks) + T_h_nom = -(-kv_block_size // split) # ceiling — nominal T per split chunk + + # kv_offsets: consecutive layout, offset of position within block + # offsets[s, t] = s*T_h_nom + t + kv_offsets = ( + torch.arange(split, device=query.device)[:, None] * T_h_nom + + torch.arange(T_h_nom, device=query.device)[None, :] + ).view(1, 1, split, 1, T_h_nom) # [1, 1, split, 1, T_h_nom] + + position_ids = cache_kwargs.get("position_ids") + current_position = position_ids.max(dim=-1).values + skip_kv = True + + max_buf: list = [] + sum_buf: list = [] + out_buf: list = [] + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + kv_len_block = ctx_len - start_index if j == num_kv_blocks - 1 else kv_block_size + end_index = start_index + kv_len_block + T_orig = kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + # Read KV block: [B, Hkv, T_orig, kv_lora_rank/qk_rope_head_dim] + ckv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs) + k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, layer_idx, cache_kwargs) + + # K in absorption or non-absorption space: [B, Hkv, T_orig, D_abs] + if absorption: + k_block = torch.cat((ckv_block, k_pe_block), dim=-1) # [B, Hkv, T, 576] + ckv_for_v = ckv_block # [B, Hkv, T, 512] + else: + # Each query head needs its own K: expand ckv to NQH=Hkv, apply per-head k_up + # ckv_block: [B, orig_Hkv, T, kv_lora_rank] + orig_Hkv = getattr(module, "num_key_value_heads", 1) + n_rep_kv = NQH // orig_Hkv + ckv_nqh = (ckv_block.unsqueeze(2) + .expand(-1, orig_Hkv, n_rep_kv, -1, -1) + .reshape(B, NQH, T_orig, kv_lora_rank)) # [B, NQH, T, 512] + k_pe_nqh = (k_pe_block.unsqueeze(2) + .expand(-1, orig_Hkv, n_rep_kv, -1, -1) + .reshape(B, NQH, T_orig, module.config.qk_rope_head_dim)) + # per_head_k_up_normal: [1, NQH, kv_lora_rank, qk_nope_head_dim] + k_nope = torch.matmul(ckv_nqh, per_head_k_up_normal) # [B, NQH, T, 128] + k_block = torch.cat((k_nope, k_pe_nqh), dim=-1) # [B, NQH, T, 192] + ckv_for_v = ckv_nqh # [B, NQH, T, 512] + + # Pad T to multiple of split + T_blk = T_orig + pad = 0 + if T_blk % split != 0: + pad = split - (T_blk % split) + k_block = F.pad(k_block, (0, 0, 0, pad)) + ckv_for_v = F.pad(ckv_for_v, (0, 0, 0, pad)) + T_blk += pad + T_h = T_blk // split + + # 5D K/V: [B, Hkv, split, T_h, D] + K_5d = k_block.view(B, Hkv, split, T_h, D_abs) + V_5d = ckv_for_v.view(B, Hkv, split, T_h, kv_lora_rank) + + # Attention scores: [B, Hkv, split, QL*n_rep, T_h] + attn = torch.matmul(Q_5d, K_5d.transpose(-1, -2)) * scaling + + # Padding mask + if pad > 0: + chunk_start = torch.arange(split, device=attn.device) * T_h + valid_in_chunk = T_orig - chunk_start + k_idx = torch.arange(T_h, device=attn.device) + pad_mask = k_idx.unsqueeze(0) >= valid_in_chunk.unsqueeze(1) # [split, T_h] + attn = attn.masked_fill(pad_mask.view(1, 1, split, 1, T_h), -3.0e4) + + # Causal mask: offsets within block vs query position + off = kv_offsets if T_h == T_h_nom else kv_offsets[:, :, :, :, :T_h] + causal_mask = off > (position_ids - start_index)[:, None, None, :, None] + attn = attn.masked_fill(causal_mask, -3.0e4) + + m_blk = attn.max(dim=-1).values # [B, Hkv, split, QL*n_rep] + exp_blk = torch.exp(attn - m_blk.unsqueeze(-1)) + + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + m_blk = torch.where(skip_future, torch.full_like(m_blk, float(MIN_MASKED_ATTENTION_VALUE)), m_blk) + exp_blk = torch.where(skip_future, torch.zeros_like(exp_blk), exp_blk) + + sum_blk = exp_blk.sum(dim=-1) # [B, Hkv, split, QL*n_rep] + out_blk = torch.matmul(exp_blk, V_5d) # [B, Hkv, split, QL*n_rep, kv_lora_rank] + + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + sum_blk = torch.where(skip_future, torch.zeros_like(sum_blk), sum_blk) + out_blk = torch.where(skip_future, torch.zeros_like(out_blk), out_blk) + + max_buf.append(m_blk) + sum_buf.append(sum_blk) + out_buf.append(out_blk) + + # -- Stage 1: merge across KV blocks -------------------------------------- + max_stk = torch.stack(max_buf) # [nkvb, B, Hkv, split, QL*n_rep] + sum_stk = torch.stack(sum_buf) + out_stk = torch.stack(out_buf) # [nkvb, B, Hkv, split, QL*n_rep, kv_lora_rank] + m1 = max_stk.max(dim=0).values + w1 = torch.exp(max_stk - m1.unsqueeze(0)) + s1 = (w1 * sum_stk).sum(dim=0) # [B, Hkv, split, QL*n_rep] + o1 = (w1.unsqueeze(-1) * out_stk).sum(dim=0) # [B, Hkv, split, QL*n_rep, kv_lora_rank] + + # -- Stage 2: merge across splits ----------------------------------------- + m2 = m1.max(dim=2).values # [B, Hkv, QL*n_rep] + w2 = torch.exp(m1 - m2.unsqueeze(2)) + s2 = (w2 * s1).sum(dim=2) + o2 = (w2.unsqueeze(-1) * o1).sum(dim=2) # [B, Hkv, QL*n_rep, kv_lora_rank] + output = o2 / s2.unsqueeze(-1) + + # -- Unfold + v_up (GQA style) --------------------------------------------- + # [B, Hkv, QL*n_rep, kv_lora_rank] ? [B, NQH, QL, kv_lora_rank] + output = output.view(B, Hkv, n_rep, QL, kv_lora_rank).reshape(B, NQH, QL, kv_lora_rank) + attn_output = torch.matmul(output, per_head_v_up) # [B, NQH, QL, v_head_dim] + attn_output = attn_output.transpose(1, 2).contiguous() # [B, QL, NQH, v_head_dim] + + return attn_output, None + + def blocked_h_mla_attention_forward( module: nn.Module, q_a_proj_out: torch.Tensor, diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index deed73a7bf..2ca059fb34 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -288,7 +288,7 @@ def build_transformer_blocking_config( ) if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): - if "kv" in blocking_mode: + if "kv" in blocking_mode or "par" in blocking_mode: attention_cfg["num_kv_blocks"] = get_num_kv_blocks_for_mla(seq_len, num_heads, ctx_len) resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv") @@ -343,6 +343,10 @@ def build_transformer_blocking_config_for_transform( if qaic_config.get("num_batch_blocks", False) and enable_blocking and "b" in blocking_mode: mode_from_config = "b" + mode_from_config blocking_config.num_batch_blocks = _get_valid_num_blocks(qaic_config, "num_batch_blocks") + if qaic_config.get("par_num_split", False) and qaic_config.get("num_kv_blocks", False) and enable_blocking and "par" in blocking_mode: + mode_from_config = "par" + mode_from_config + blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") + blocking_config.par_num_split = _get_valid_num_blocks(qaic_config, "par_num_split") # check if qaic config did not provide any blocking details if mode_from_config == "": diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index c18993f692..e2b7a47ff7 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -407,6 +407,84 @@ def fused_forward_kv_blocking( return attn_output, None, compressed_kvs + def fused_forward_par_kv_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + # -- KV compression (write to cache) ---------------------------------- + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if compressed_kvs is not None: + compressed_kvs.write_only_ckv(kva, self.layer_idx, cache_kwargs) + + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + if compressed_kvs is not None: + compressed_kvs.write_only_k_pe(k_pe, self.layer_idx, cache_kwargs) + + # -- Build query in absorption space ----------------------------------- + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) + else: + absorption = False + + if absorption: + if online: + qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) + query = torch.cat((dq_qup_kupT, q_pe), dim=-1) # [B, num_heads, q_len, d_abs] + else: + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + query = torch.cat((q_nope, q_pe), dim=-1) + + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + query=query, + per_head_v_up=self.per_head_v_up, + per_head_k_up_normal=self.per_head_k_up_normal, + attention_mask=attention_mask, + scaling=self.softmax_scale, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + compressed_kvs=compressed_kvs, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, None, compressed_kvs + def fused_forward_orig( self, hidden_states: torch.Tensor, @@ -572,6 +650,21 @@ def fused_forward( mla_absorption, **kwargs, ) + elif getattr(blocking_config, "mode", None) == "par": + return self.fused_forward_par_kv_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) else: return self.fused_forward_orig( hidden_states, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 3792104a04..7a2b86b94a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1048,6 +1048,7 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "fused_forward": QEffDeepseekV3Attention.fused_forward, "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, "fused_forward_kv_blocking": QEffDeepseekV3Attention.fused_forward_kv_blocking, + "fused_forward_par_kv_blocking": QEffDeepseekV3Attention.fused_forward_par_kv_blocking, "fused_forward_orig": QEffDeepseekV3Attention.fused_forward_orig, "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, }, diff --git a/examples/kimi_k2/export_kimi_k25_4_experts_only.py b/examples/kimi_k2/export_kimi_k25_4_experts_only.py index 384d0f9da0..4f2571e2a4 100644 --- a/examples/kimi_k2/export_kimi_k25_4_experts_only.py +++ b/examples/kimi_k2/export_kimi_k25_4_experts_only.py @@ -39,6 +39,8 @@ # qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} # for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "par", "par_num_split": 4, "num_kv_blocks": 8} + MODEL_PATH = Path( "/home/huggingface_hub/models--moonshotai--Kimi-K2.5/snapshots/54383e83fa343a1331754112fb9e3410c55efa2f" From 497c0033d00f0f451e2dac309374e88cffe3cc29 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 18 May 2026 14:53:40 +0530 Subject: [PATCH 15/17] kv_par latest patch Signed-off-by: Mamta Singh --- .../blocking/blocked_attention_forwards.py | 243 ++++++++---------- 1 file changed, 113 insertions(+), 130 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 2110afb98c..f1061747af 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -959,79 +959,67 @@ def blocked_kv_mla_attention_forward( def blocked_kv_par_mla_attention_forward( module: nn.Module, - query: torch.Tensor, # [B, NQH, QL, D_abs] absorption-space Q - per_head_v_up: torch.Tensor, # [1, NQH, kv_lora_rank, v_head_dim] - per_head_k_up_normal: torch.Tensor, # [1, NQH, qk_nope_head_dim, kv_lora_rank] — for non-absorption K - mla_absorption: Dict[str, Any], + query: torch.Tensor, + per_head_v_up: torch.Tensor, + per_head_k_up_normal: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, num_kv_blocks: int, - par_num_split: int, # T-dim split within each KV block (maps to NSP cores) + par_num_split: int, cache_kwargs: Dict[str, Any], layer_idx: int, compressed_kvs, + mla_absorption: Optional[Dict[str, Any]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - GQA headpar-style MLA attention. - - Layout matches qwen3_gqa_kv_blocking_microbench.py: - q_fold = query.reshape(B, Hkv, QL*n_rep, D) # simple reshape, no permute - Q_5d = q_fold.unsqueeze(2).expand(..., split, ...) # broadcast over split - K_5d = k_block.view(B, Hkv, split, T_h, D) # consecutive T split - - Merge is two-stage offline (buffer all blocks): - Stage 1: max/exp/sum across KV blocks - Stage 2: max/exp/sum across splits - """ - B, NQH, QL, D_abs = query.shape + batch_size, num_heads, q_len, query_dim = query.shape kv_lora_rank = module.config.kv_lora_rank - split = par_num_split if mla_absorption is not None: absorption = mla_absorption.get("absorption", False) else: absorption = False - # absorption=True : all n_rep heads in a group share the same K (= ckv||k_pe) - # ? fold: Hkv = module.num_key_value_heads, n_rep = NQH // Hkv - # absorption=False: each query head has its own K = ckv @ k_up_h - # ? cannot fold across heads; treat as Hkv=NQH, n_rep=1 + position_ids = cache_kwargs.get("position_ids") + if absorption: - Hkv = getattr(module, "num_key_value_heads", 1) - n_rep = NQH // Hkv + num_key_value_heads = getattr(module, "num_key_value_heads", 1) + num_heads_per_kv = num_heads // num_key_value_heads else: - Hkv = NQH - n_rep = 1 + num_key_value_heads = num_heads + num_heads_per_kv = 1 - # -- Q fold: reshape + unsqueeze + expand (GQA style) --------------------- - q_fold = query.reshape(B, Hkv, QL * n_rep, D_abs) - Q_5d = q_fold.unsqueeze(2).expand(B, Hkv, split, QL * n_rep, D_abs) + split = par_num_split + split = max(1, _normalize_int(split)) - ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2] - kv_block_size = -(-ctx_len // num_kv_blocks) - T_h_nom = -(-kv_block_size // split) # ceiling — nominal T per split chunk + q_fold = query.reshape(batch_size, num_key_value_heads, q_len * num_heads_per_kv, query_dim) + q_5d = q_fold.unsqueeze(2).expand( + batch_size, num_key_value_heads, split, q_len * num_heads_per_kv, query_dim + ) - # kv_offsets: consecutive layout, offset of position within block - # offsets[s, t] = s*T_h_nom + t + ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2] + kv_block_size = -(-ctx_len // num_kv_blocks) + split_block_size = -(-kv_block_size // split) kv_offsets = ( - torch.arange(split, device=query.device)[:, None] * T_h_nom - + torch.arange(T_h_nom, device=query.device)[None, :] - ).view(1, 1, split, 1, T_h_nom) # [1, 1, split, 1, T_h_nom] + torch.arange(split, device=query.device)[:, None] * split_block_size + + torch.arange(split_block_size, device=query.device)[None, :] + ).view(1, 1, split, 1, split_block_size) - position_ids = cache_kwargs.get("position_ids") + masked_tensor = torch.tensor(-3.0e4, dtype=query.dtype, device=query.device) current_position = position_ids.max(dim=-1).values skip_kv = True - max_buf: list = [] - sum_buf: list = [] - out_buf: list = [] + max_blocks = [] + sum_blocks = [] + output_blocks = [] - for j in range(num_kv_blocks): - start_index = j * kv_block_size - kv_len_block = ctx_len - start_index if j == num_kv_blocks - 1 else kv_block_size - end_index = start_index + kv_len_block - T_orig = kv_len_block + for block_idx in range(num_kv_blocks): + start_index = block_idx * kv_block_size + if block_idx == num_kv_blocks - 1: + kv_len_block = ctx_len - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block skip_future = None if skip_kv: @@ -1040,99 +1028,94 @@ def blocked_kv_par_mla_attention_forward( if skip_future.item(): break - # Read KV block: [B, Hkv, T_orig, kv_lora_rank/qk_rope_head_dim] - ckv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs) + compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs) k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, layer_idx, cache_kwargs) - # K in absorption or non-absorption space: [B, Hkv, T_orig, D_abs] if absorption: - k_block = torch.cat((ckv_block, k_pe_block), dim=-1) # [B, Hkv, T, 576] - ckv_for_v = ckv_block # [B, Hkv, T, 512] + key_block = torch.cat((compressed_kv_block, k_pe_block), dim=-1) + value_block = compressed_kv_block else: - # Each query head needs its own K: expand ckv to NQH=Hkv, apply per-head k_up - # ckv_block: [B, orig_Hkv, T, kv_lora_rank] - orig_Hkv = getattr(module, "num_key_value_heads", 1) - n_rep_kv = NQH // orig_Hkv - ckv_nqh = (ckv_block.unsqueeze(2) - .expand(-1, orig_Hkv, n_rep_kv, -1, -1) - .reshape(B, NQH, T_orig, kv_lora_rank)) # [B, NQH, T, 512] - k_pe_nqh = (k_pe_block.unsqueeze(2) - .expand(-1, orig_Hkv, n_rep_kv, -1, -1) - .reshape(B, NQH, T_orig, module.config.qk_rope_head_dim)) - # per_head_k_up_normal: [1, NQH, kv_lora_rank, qk_nope_head_dim] - k_nope = torch.matmul(ckv_nqh, per_head_k_up_normal) # [B, NQH, T, 128] - k_block = torch.cat((k_nope, k_pe_nqh), dim=-1) # [B, NQH, T, 192] - ckv_for_v = ckv_nqh # [B, NQH, T, 512] - - # Pad T to multiple of split - T_blk = T_orig - pad = 0 - if T_blk % split != 0: - pad = split - (T_blk % split) - k_block = F.pad(k_block, (0, 0, 0, pad)) - ckv_for_v = F.pad(ckv_for_v, (0, 0, 0, pad)) - T_blk += pad - T_h = T_blk // split - - # 5D K/V: [B, Hkv, split, T_h, D] - K_5d = k_block.view(B, Hkv, split, T_h, D_abs) - V_5d = ckv_for_v.view(B, Hkv, split, T_h, kv_lora_rank) - - # Attention scores: [B, Hkv, split, QL*n_rep, T_h] - attn = torch.matmul(Q_5d, K_5d.transpose(-1, -2)) * scaling - - # Padding mask - if pad > 0: - chunk_start = torch.arange(split, device=attn.device) * T_h - valid_in_chunk = T_orig - chunk_start - k_idx = torch.arange(T_h, device=attn.device) - pad_mask = k_idx.unsqueeze(0) >= valid_in_chunk.unsqueeze(1) # [split, T_h] - attn = attn.masked_fill(pad_mask.view(1, 1, split, 1, T_h), -3.0e4) + original_kv_heads = getattr(module, "num_key_value_heads", 1) + num_repeats = num_heads // original_kv_heads + compressed_kv_block = ( + compressed_kv_block.unsqueeze(2) + .expand(-1, original_kv_heads, num_repeats, -1, -1) + .reshape(batch_size, num_heads, kv_len_block, kv_lora_rank) + ) + k_pe_block = ( + k_pe_block.unsqueeze(2) + .expand(-1, original_kv_heads, num_repeats, -1, -1) + .reshape(batch_size, num_heads, kv_len_block, module.config.qk_rope_head_dim) + ) + k_nope_block = torch.matmul(compressed_kv_block, per_head_k_up_normal) + key_block = torch.cat((k_nope_block, k_pe_block), dim=-1) + value_block = compressed_kv_block - # Causal mask: offsets within block vs query position - off = kv_offsets if T_h == T_h_nom else kv_offsets[:, :, :, :, :T_h] - causal_mask = off > (position_ids - start_index)[:, None, None, :, None] - attn = attn.masked_fill(causal_mask, -3.0e4) + pad = 0 + padded_kv_len = kv_len_block + if padded_kv_len % split != 0: + pad = split - (padded_kv_len % split) + key_block = F.pad(key_block, (0, 0, 0, pad)) + value_block = F.pad(value_block, (0, 0, 0, pad)) + padded_kv_len += pad - m_blk = attn.max(dim=-1).values # [B, Hkv, split, QL*n_rep] - exp_blk = torch.exp(attn - m_blk.unsqueeze(-1)) + per_split_kv_len = padded_kv_len // split + key_5d = key_block.view(batch_size, num_key_value_heads, split, per_split_kv_len, query_dim) + value_5d = value_block.view(batch_size, num_key_value_heads, split, per_split_kv_len, kv_lora_rank) - if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): - m_blk = torch.where(skip_future, torch.full_like(m_blk, float(MIN_MASKED_ATTENTION_VALUE)), m_blk) - exp_blk = torch.where(skip_future, torch.zeros_like(exp_blk), exp_blk) + attn_weights_block = torch.matmul(q_5d, key_5d.transpose(-1, -2)) * scaling + + if pad > 0: + chunk_start = torch.arange(split, device=query.device) * per_split_kv_len + valid_in_chunk = kv_len_block - chunk_start + kv_indices = torch.arange(per_split_kv_len, device=query.device) + pad_mask = kv_indices.unsqueeze(0) >= valid_in_chunk.unsqueeze(1) + attn_weights_block = torch.where( + pad_mask.view(1, 1, split, 1, per_split_kv_len), masked_tensor, attn_weights_block + ) + + offsets = kv_offsets if per_split_kv_len == split_block_size else kv_offsets[:, :, :, :, :per_split_kv_len] + causal_mask = offsets > (position_ids - start_index)[:, None, None, :, None] + attn_weights_block = torch.where(causal_mask, masked_tensor, attn_weights_block) - sum_blk = exp_blk.sum(dim=-1) # [B, Hkv, split, QL*n_rep] - out_blk = torch.matmul(exp_blk, V_5d) # [B, Hkv, split, QL*n_rep, kv_lora_rank] + block_max = attn_weights_block.max(dim=-1).values + block_exp = torch.exp(attn_weights_block - block_max.unsqueeze(-1)) + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + block_max = torch.where( + skip_future, torch.full_like(block_max, float(MIN_MASKED_ATTENTION_VALUE)), block_max + ) + block_exp = torch.where(skip_future, torch.zeros_like(block_exp), block_exp) + block_sum = block_exp.sum(dim=-1) + block_output = torch.matmul(block_exp, value_5d) if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): - sum_blk = torch.where(skip_future, torch.zeros_like(sum_blk), sum_blk) - out_blk = torch.where(skip_future, torch.zeros_like(out_blk), out_blk) - - max_buf.append(m_blk) - sum_buf.append(sum_blk) - out_buf.append(out_blk) - - # -- Stage 1: merge across KV blocks -------------------------------------- - max_stk = torch.stack(max_buf) # [nkvb, B, Hkv, split, QL*n_rep] - sum_stk = torch.stack(sum_buf) - out_stk = torch.stack(out_buf) # [nkvb, B, Hkv, split, QL*n_rep, kv_lora_rank] - m1 = max_stk.max(dim=0).values - w1 = torch.exp(max_stk - m1.unsqueeze(0)) - s1 = (w1 * sum_stk).sum(dim=0) # [B, Hkv, split, QL*n_rep] - o1 = (w1.unsqueeze(-1) * out_stk).sum(dim=0) # [B, Hkv, split, QL*n_rep, kv_lora_rank] - - # -- Stage 2: merge across splits ----------------------------------------- - m2 = m1.max(dim=2).values # [B, Hkv, QL*n_rep] - w2 = torch.exp(m1 - m2.unsqueeze(2)) - s2 = (w2 * s1).sum(dim=2) - o2 = (w2.unsqueeze(-1) * o1).sum(dim=2) # [B, Hkv, QL*n_rep, kv_lora_rank] - output = o2 / s2.unsqueeze(-1) - - # -- Unfold + v_up (GQA style) --------------------------------------------- - # [B, Hkv, QL*n_rep, kv_lora_rank] ? [B, NQH, QL, kv_lora_rank] - output = output.view(B, Hkv, n_rep, QL, kv_lora_rank).reshape(B, NQH, QL, kv_lora_rank) - attn_output = torch.matmul(output, per_head_v_up) # [B, NQH, QL, v_head_dim] - attn_output = attn_output.transpose(1, 2).contiguous() # [B, QL, NQH, v_head_dim] + block_sum = torch.where(skip_future, torch.zeros_like(block_sum), block_sum) + block_output = torch.where(skip_future, torch.zeros_like(block_output), block_output) + + max_blocks.append(block_max) + sum_blocks.append(block_sum) + output_blocks.append(block_output) + + max_stacked = torch.stack(max_blocks) + sum_stacked = torch.stack(sum_blocks) + output_stacked = torch.stack(output_blocks) + + max_across_blocks = max_stacked.max(dim=0).values + weights_across_blocks = torch.exp(max_stacked - max_across_blocks.unsqueeze(0)) + sum_across_blocks = (weights_across_blocks * sum_stacked).sum(dim=0) + output_across_blocks = (weights_across_blocks.unsqueeze(-1) * output_stacked).sum(dim=0) + + max_across_splits = max_across_blocks.max(dim=2).values + weights_across_splits = torch.exp(max_across_blocks - max_across_splits.unsqueeze(2)) + sum_across_splits = (weights_across_splits * sum_across_blocks).sum(dim=2) + output_across_splits = (weights_across_splits.unsqueeze(-1) * output_across_blocks).sum(dim=2) + + output = output_across_splits / sum_across_splits.unsqueeze(-1) + output = output.view(batch_size, num_key_value_heads, num_heads_per_kv, q_len, kv_lora_rank).reshape( + batch_size, num_heads, q_len, kv_lora_rank + ) + attn_output = torch.matmul(output, per_head_v_up) + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None From bac4cadb9a664ef12a4b2887ba7570b5b5d04143 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 18 May 2026 23:13:18 +0530 Subject: [PATCH 16/17] fix(0415): align prefill MoE chunk export with packed dispatch Signed-off-by: Mamta Singh --- .../transformers/models/deepseek_v3/modeling_deepseek.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index e2b7a47ff7..7a5bb9a4ce 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -869,9 +869,7 @@ def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) valid_dest = valid_prefix - 1 scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) - # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this - # can be switched back to ``torch.full_like(token_idx, int32_max)``. - matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = torch.full_like(token_idx, int32_max) matched_idx = CtxScatterFunc3DInt.apply( matched_idx.unsqueeze(-1), scatter_pos, From f73721ee300af9692cdd595cea00e94c32aa75a5 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 18 May 2026 23:22:17 +0530 Subject: [PATCH 17/17] fix(0415): align prefill MoE chunk export with packed dispatch 2 Signed-off-by: Mamta Singh --- QEfficient/transformers/models/modeling_auto.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f1b8b6e3cc..e605f80ecd 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2993,7 +2993,9 @@ def get_seq_len_and_handle_specialized_prefill_model( self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True - return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + seq_len = max(prefill_seq_len or 0, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + self.hash_params["chunking_seq_len"] = seq_len + return seq_len num_q_blocks = ( self.hash_params["blocking_config"].num_q_blocks if self.hash_params.get("blocking_kwargs", None) else None @@ -3104,6 +3106,7 @@ def export( self.hash_params.pop("NUM_FFN_BLOCKS", None) self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) + self.hash_params.pop("chunking_seq_len", None) if kwargs.get("retain_full_kv", False): kv_cache_shape[2] = seq_len + ( self.model.config.sliding_window if self.model.config.sliding_window is not None else 0