From 06d05a2796025b102757011f34bdc4cb466a33b4 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 28 Apr 2026 13:24:39 +0530 Subject: [PATCH 1/5] add int4 changes Signed-off-by: Mamta Singh --- QEfficient/base/onnx_transforms.py | 2 + QEfficient/base/pytorch_transforms.py | 13 +- QEfficient/customop/matmulnbits.py | 88 +++- QEfficient/customop/quantization_ops.py | 149 +++++++ .../models/deepseek_v3/modeling_deepseek.py | 380 +++++++++++++++--- .../transformers/models/modeling_auto.py | 2 + .../transformers/models/pytorch_transforms.py | 6 +- .../quantizers/quant_transforms.py | 39 ++ .../quantizer_compressed_tensors.py | 158 +++++++- 9 files changed, 761 insertions(+), 76 deletions(-) create mode 100644 QEfficient/customop/quantization_ops.py diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index c27e3cc704..32bb0a8fa4 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -39,6 +39,7 @@ CtxScatterFuncCB, CtxScatterFuncCB3D, ) +from QEfficient.customop.quantization_ops import CastToUInt4, CastToUInt4Func from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc from QEfficient.utils.constants import FILE_CHUNK_SIZE_DEFAULT, ONNX_EXPORT_OPSET, SIZE_THRESHOLD_DEFAULT @@ -100,6 +101,7 @@ class CustomOpTransform(BaseOnnxTransform): "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), + "CastToUInt4": (CastToUInt4Func, CastToUInt4), } @classmethod diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index 812177eac2..f55e5872ff 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -4,8 +4,9 @@ # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- + from types import MethodType -from typing import Callable, 
Dict, Tuple, Type +from typing import Callable, Dict, Optional, Tuple, Type from torch import nn @@ -97,6 +98,7 @@ class ModuleMutatorTransform(PytorchTransform): """ _match_class: nn.Module + _match_string: Optional[str] = None @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: @@ -135,7 +137,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__) ): for orig_method_name, mapped_method in repl_method_map.items(): - setattr(module, orig_method_name, MethodType(mapped_method, module)) + parts = orig_method_name.split(".") + if len(parts) > 1: + target = module + for part in parts[:-1]: + target = getattr(target, part) + setattr(target, parts[-1], MethodType(mapped_method, target)) + else: + setattr(module, orig_method_name, MethodType(mapped_method, module)) if hasattr(module, "__qeff_init__"): module.__qeff_init__() diff --git a/QEfficient/customop/matmulnbits.py b/QEfficient/customop/matmulnbits.py index e6249b0ad3..468f32e641 100644 --- a/QEfficient/customop/matmulnbits.py +++ b/QEfficient/customop/matmulnbits.py @@ -14,7 +14,10 @@ class QuantLinearTorchFunction(torch.autograd.Function): @staticmethod def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): - input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) + if qself_qzeros is None:  # NOTE(review): inputs are positional, so with zero points omitted a non-None g_idx would take the zero_points slot of MatMulNBits — confirm + input_tuple = (x, qself_qweight, qself_scales) + else: + input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) input_tuple += (g_idx,) if g_idx is not None else () return g.op( "com.microsoft::MatMulNBits", @@ -28,7 +31,10 @@ def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group @staticmethod def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): + if qself_qzeros is None:  # no zero points supplied: fall back to the mid-code 2**(bits-1), i.e. 8 for 4-bit ("^" would be XOR and yield 1) + qself_qzeros = 2 ** (bits - 1) if torch.onnx.is_in_onnx_export(): + # For faster 
export return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float() fp_weight = dequantize_blockwise_bits( qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features @@ -40,8 +46,7 @@ def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, grou def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols): if bits != 4: raise ValueError("Only bits=4 is supported for executing quantized model") - if group_size != 128: - raise ValueError("Only group_size=128 is supported for executing quantized model") + expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1) aligned_scale = scale.reshape(*quant_values.shape[:-1], 1) @@ -88,20 +93,20 @@ def __init__(self, bits, group_size, in_features, out_features, bias): q_rows = in_features // self.group_size self.register_buffer( "qweight", - torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8), + torch.empty((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8), ) self.register_buffer( "qzeros", - torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8), + torch.empty((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8), ) self.register_buffer( - "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16) + "scales", torch.empty((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16) ) self.register_buffer( "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32) ) if bias: - self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16)) + self.register_buffer("bias", torch.empty((out_features), dtype=torch.float16)) else: self.bias = None @@ -180,3 +185,72 @@ def 
forward(self, inputs): ) out = out + self.bias if self.bias is not None else out return out + + +class QMOE(torch.autograd.Function): + @staticmethod + def symbolic( + g, + x, + router_weights, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + router_probs, + activation_type, + block_size, + expert_weight_bits, + k, + ): + qmoe_out = g.op( + "com.microsoft::QMoE", + x, + router_weights, + router_probs, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + outputs=1, + activation_type_s=activation_type, # <-- _s suffix for string + block_size_i=block_size, + expert_weight_bits_i=expert_weight_bits, + k_i=k, + ) + + # # Create axes=-1 as an explicit int64 constant tensor + # axes = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + # # Compute mean of router_probs along the last axis, keepdims for broadcasting + # router_probs_mean = g.op("ReduceMean", router_probs, axes, keepdims_i=1) + + # The router_probs scaling (ReduceMean + Mul, commented out around this return) is disabled: the raw QMoE output is returned + return qmoe_out + # return g.op("Mul", qmoe_out, router_probs_mean)) + + @staticmethod + def forward( + ctx, + x, + router_weights, + fc1_experts_weights, + fc1_scales, + fc2_experts_weights, + fc2_scales, + fc3_experts_weights, + fc3_scales, + router_probs, + activation_type, + block_size, + expert_weight_bits, + k, + ): + # Dummy forward: zeros_like(x) times the last-axis mean of router_probs, so the values are all zero. NOTE(review): symbolic() emits no ReduceMean/Mul — confirm the mismatch is intended + qmoe_out = torch.zeros_like(x) + router_probs_mean = router_probs.mean(dim=-1, keepdim=True) + return qmoe_out * router_probs_mean diff --git a/QEfficient/customop/quantization_ops.py b/QEfficient/customop/quantization_ops.py new file mode 100644 index 0000000000..3878804c7d --- /dev/null +++ b/QEfficient/customop/quantization_ops.py @@ -0,0 +1,149 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import onnxscript +import torch +from onnx import TensorProto + +from QEfficient.utils import constants + +ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) + + +@onnxscript.script(onnxscript.values.Opset("com.qti.aisw.onnx", 1)) +def CastToUInt4(weight_packed: onnxscript.UINT8) -> onnxscript.UINT8: + """ + Unpack packed uint8 weights into uint4 values and cast output to UINT4. + Supports N-D input: all leading dimensions are preserved; only the last + dimension (in_features // 2) is doubled to (in_features). + + Input: (..., in_features // 2) UINT8 + Each byte holds two nibbles: byte = (w_y << 4) | (w_x & 0x0F) + Output: (..., in_features) UINT4, values in [0, 15] + + Operations: + w_x = weight_packed % 16 (lower nibble) + w_y = (weight_packed >> 4) % 16 (upper nibble) + stacked = concat([w_x, w_y], axis=-1) after unsqueeze + → (..., in//2, 2) + leading_dims = shape[:-1] + new_shape = [...leading_dims, last_dim * 2] + reshaped = reshape(stacked, new_shape) + output = Cast(reshaped, to=UINT4) + """ + sixteen = ops.CastLike(ops.Constant(value_ints=[16]), weight_packed) + + # Lower nibble: weight_packed & 0x0F = weight_packed % 16 + w_x = ops.Mod(weight_packed, sixteen) + + # Upper nibble: (weight_packed >> 4) & 0x0F + shift = ops.CastLike(ops.Constant(value_ints=[4]), weight_packed) + w_shifted = ops.BitShift(weight_packed, shift, direction="RIGHT") + w_y = ops.Mod(w_shifted, sixteen) + + # Stack along a new last dim → (..., in_features//2, 2) + w_x_unsq = ops.Unsqueeze(w_x, [-1]) + w_y_unsq = ops.Unsqueeze(w_y, [-1]) + stacked = ops.Concat(w_x_unsq, w_y_unsq, axis=-1) + + # N-D aware reshape: preserve all leading dims, double the last dim. + # packed_shape = [d0, d1, ..., last_dim] + packed_shape = ops.Shape(weight_packed) + # All dims except the last: [d0, d1, ...] 
+ leading_dims = ops.Slice(packed_shape, starts=[0], ends=[-1], axes=[0]) + # Last dim only: [last_dim] + last_dim = ops.Slice(packed_shape, starts=[-1], ends=[2147483647], axes=[0]) + # Double the last dim: [last_dim * 2] + last_dim_doubled = ops.Mul(last_dim, ops.Constant(value_ints=[2])) + # New shape: [d0, d1, ..., last_dim * 2] + new_shape = ops.Concat(leading_dims, last_dim_doubled, axis=0) + reshaped = ops.Reshape(stacked, new_shape) + + # Cast to UINT4 — the data_type value is read from the TensorProto enum rather than hard-coded as an integer + return ops.Cast(reshaped, to=int(TensorProto.UINT4)) + + +class CastToUInt4Func(torch.autograd.Function): + """ + Custom op: unpacks packed uint8 → uint8 (values 0-15) in PyTorch. + In ONNX the custom op subgraph includes a Cast → UINT4 as its last step. + Supports N-D input: all leading dimensions are preserved. + + PyTorch forward : packed uint8 (..., in//2) → uint8 (..., in), values [0, 15] + ONNX symbolic : emits CastToUInt4 node (com.qti.aisw.onnx) + The subgraph ends with Cast → UINT4. + """ + + @staticmethod + def forward(weight_packed: torch.Tensor) -> torch.Tensor: + w_x = weight_packed & 0x0F # lower nibble, (..., in//2), range [0, 15] + w_y = (weight_packed >> 4) & 0x0F # upper nibble, (..., in//2), range [0, 15] + # New shape: all leading dims unchanged, last dim doubled + new_shape = list(weight_packed.shape[:-1]) + [weight_packed.shape[-1] * 2] + return torch.stack( + [w_x, w_y], dim=-1 + ).reshape( + new_shape + ) # A cast to uint4 cannot be added here because it is not supported in PyTorch; the ONNX export handles the cast to UINT4 through the symbolic method. + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, weight_packed: torch.Value) -> torch.Value: + output = g.onnxscript_op(CastToUInt4, weight_packed) + return output + + +class DequantizeLinearFunc(torch.autograd.Function): + """ + Emits a standard ONNX DequantizeLinear node (ai.onnx domain, not custom). 
+ + Blockwise quantization with an explicit per-block zero point: + output = (x - zeros) * scale (per block along the last axis) + + Supports N-D input in the PyTorch forward: + weight_unpacked : (..., in_features) — quantized values + scale, zeros : (..., num_blocks) — per-block scales / zero points (zeros is expanded exactly like scale) + block_size : int — elements per block + + PyTorch forward : expand blockwise scale and zeros along last dim, subtract zeros, multiply + ONNX symbolic : DequantizeLinear(weight_unpacked, scale, zeros, + axis=2, block_size=block_size) + axis=2 is hard-coded, i.e. the export assumes 3D input (N, out_features, in_features). + zeros is passed as the zero_point input. + """ + + @staticmethod + def forward( + weight_unpacked: torch.Tensor, scale: torch.Tensor, zeros: torch.Tensor, block_size: int + ) -> torch.Tensor: + # Expand per-block scale and zeros → per-element values along last dim + scale_expanded = scale.repeat_interleave(block_size, dim=-1) + zeros_expanded = zeros.repeat_interleave(block_size, dim=-1) + return (weight_unpacked.to(torch.int8) - zeros_expanded.to(torch.int8)) * scale_expanded + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic( + g: torch.Graph, weight_unpacked: torch.Value, scale: torch.Value, zeros: torch.Value, block_size: int + ) -> torch.Value: + # Standard DequantizeLinear: blockwise, with zeros passed as the zero_point input. + # Input is assumed 3D: (N, out_features, in_features) → axis=2 (last dim). + # DequantizeLinear natively supports batch dimensions. 
+ return g.op( + "DequantizeLinear", + weight_unpacked, + scale, + zeros, + axis_i=2, + block_size_i=block_size, + ) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 0e1f270e20..6c3bd0c083 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -20,6 +20,8 @@ generic_blocked_attention_interface, generic_blocked_mla_attention_interface, ) +from QEfficient.customop.matmulnbits import QMOE, QuantLinearTorchFunction +from QEfficient.customop.quantization_ops import CastToUInt4Func, DequantizeLinearFunc from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffDynamicCompressedKVRopeCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -740,30 +742,152 @@ def forward( ) +class QEffDeepseekMoEGate(nn.Module): + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + if self.topk_method == "noaux_tc": + assert not self.training + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + 
group_mask.unsqueeze(-1) + .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + router_probs = tmp_scores + router_weights = scores + return topk_idx, topk_weight, router_probs, router_weights + + class QEffDeepseekV3MoE(nn.Module): def __qeff_init__( self, ): - self.all_gate_proj = torch.nn.Parameter( - torch.cat( - [exp.gate_proj.compressor.decompress_module(exp.gate_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) + # Get common parameters from first expert + first_expert = self.experts[0] + self.bits = first_expert.gate_proj.bits + self.group_size = first_expert.gate_proj.group_size + self.act_fn = first_expert.act_fn + assert first_expert.gate_proj.act_order == first_expert.up_proj.act_order == first_expert.down_proj.act_order, ( + "act_order mismatch" ) - self.all_up_proj = torch.nn.Parameter( - torch.cat( - [exp.up_proj.compressor.decompress_module(exp.up_proj).T.unsqueeze(0) for exp in self.experts], dim=0 - ) + self.act_order = first_expert.gate_proj.act_order + + # Store dimensions for dequantization + self.in_features_gate, self.out_features_gate = ( + first_expert.gate_proj.in_features, + first_expert.gate_proj.out_features, ) - self.all_down_proj = torch.nn.Parameter( - torch.cat( - [exp.down_proj.compressor.decompress_module(exp.down_proj).T.unsqueeze(0) for exp in self.experts], - dim=0, - ) + self.in_features_up, 
self.out_features_up = first_expert.up_proj.in_features, first_expert.up_proj.out_features + self.in_features_down, self.out_features_down = ( + first_expert.down_proj.in_features, + first_expert.down_proj.out_features, + ) + + # Stack all parameters along a new dimension (expert dimension) + self.all_gate_qweight = torch.nn.Parameter( + torch.stack([exp.gate_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // 2 + ), + requires_grad=False, + ) + self.all_gate_scales = torch.nn.Parameter( + torch.stack([exp.gate_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // self.group_size + ), + requires_grad=False, + ) + # TODO: Since we know qzeros is always 8 -> Just embed this once into the operator as parameter -> explore this later + self.all_gate_qzeros = torch.nn.Parameter( + torch.stack([exp.gate_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_gate, self.in_features_gate // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_gate_gidx = torch.nn.Parameter( + torch.stack([exp.gate_proj.g_idx for exp in self.experts], dim=0), requires_grad=False ) - self.act_fn = self.experts[0].act_fn - def moe( + self.all_up_qweight = torch.nn.Parameter( + torch.stack([exp.up_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // 2 + ), + requires_grad=False, + ) + self.all_up_scales = torch.nn.Parameter( + torch.stack([exp.up_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // self.group_size + ), + requires_grad=False, + ) + self.all_up_qzeros = torch.nn.Parameter( + torch.stack([exp.up_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_up, self.in_features_up // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_up_gidx = torch.nn.Parameter( + torch.stack([exp.up_proj.g_idx for exp 
in self.experts], dim=0), requires_grad=False + ) + + self.all_down_qweight = torch.nn.Parameter( + torch.stack([exp.down_proj.qweight for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // 2 + ), + requires_grad=False, + ) + self.all_down_scales = torch.nn.Parameter( + torch.stack([exp.down_proj.scales for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // self.group_size + ), + requires_grad=False, + ) + self.all_down_qzeros = torch.nn.Parameter( + torch.stack([exp.down_proj.qzeros for exp in self.experts], dim=0).reshape( + -1, self.out_features_down, self.in_features_down // (self.group_size * 2) + ), + requires_grad=False, + ) + self.all_down_gidx = torch.nn.Parameter( + torch.stack([exp.down_proj.g_idx for exp in self.experts], dim=0), requires_grad=False + ) + + # self.fc1_experts_weights = torch.nn.Parameter(all_gate_qweight.view(self.config.n_routed_experts, self.out_features_gate, -1), requires_grad=False) + # self.fc1_scales = torch.nn.Parameter(all_gate_scales.view(self.config.n_routed_experts, self.out_features_gate, -1), requires_grad=False) + + # self.fc2_experts_weights = torch.nn.Parameter(all_up_qweight.view(self.config.n_routed_experts, self.out_features_up, -1), requires_grad=False) + # self.fc2_scales = torch.nn.Parameter(all_up_scales.view(self.config.n_routed_experts, self.out_features_up, -1), requires_grad=False) + + # self.fc3_experts_weights = torch.nn.Parameter(all_down_qweight.view(self.config.n_routed_experts, self.out_features_down, -1), requires_grad=False) + # self.fc3_scales = torch.nn.Parameter(all_down_scales.view(self.config.n_routed_experts, self.out_features_down, -1), requires_grad=False) + + def moe_old( self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, @@ -773,49 +897,214 @@ def moe( hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - 
gate_proj = self.all_gate_proj[topk_indices.flatten()] - up_proj = self.all_up_proj[topk_indices.flatten()] - down_proj = self.all_down_proj[topk_indices.flatten()] + for i in range(self.gate.top_k): + expert_idx = topk_indices[:, i] + curr_weight = topk_weights[:, i] + gate_qweight = self.all_gate_qweight[expert_idx].reshape( + seq_len * self.out_features_gate, + self.in_features_gate // self.group_size, + (self.group_size * self.bits) // 8, + ) + gate_scales = self.all_gate_scales[expert_idx].reshape( + seq_len * self.out_features_gate * (self.in_features_gate // self.group_size) + ) + gate_qzeros = self.all_gate_qzeros[expert_idx].reshape( + seq_len * self.out_features_gate, self.in_features_gate // self.group_size + ) + gate_gidx = self.all_gate_gidx[expert_idx].reshape(seq_len * self.in_features_gate) + + up_qweight = self.all_up_qweight[expert_idx].reshape( + seq_len * self.out_features_up, + self.in_features_up // self.group_size, + (self.group_size * self.bits) // 8, + ) + up_scales = self.all_up_scales[expert_idx].reshape( + seq_len * self.out_features_up * (self.in_features_up // self.group_size) + ) + up_qzeros = self.all_up_qzeros[expert_idx].reshape( + seq_len * self.out_features_up, self.in_features_up // self.group_size + ) + up_gidx = self.all_up_gidx[expert_idx].reshape(seq_len * self.in_features_up) + + down_qweight = self.all_down_qweight[expert_idx].reshape( + seq_len * self.out_features_down, + self.in_features_down // self.group_size, + (self.group_size * self.bits) // 8, + ) + down_scales = self.all_down_scales[expert_idx].reshape( + seq_len * self.out_features_down * (self.in_features_down // self.group_size) + ) + down_qzeros = self.all_down_qzeros[expert_idx].reshape( + seq_len * self.out_features_down, self.in_features_down // self.group_size + ) + down_gidx = self.all_down_gidx[expert_idx].reshape(seq_len * self.in_features_down) + + gate_out = QuantLinearTorchFunction.apply( + hidden_states, + gate_qweight, + gate_scales, + gate_qzeros, 
+ gate_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_gate, + self.out_features_gate * seq_len, + ) + + up_out = QuantLinearTorchFunction.apply( + hidden_states, + up_qweight, + up_scales, + up_qzeros, + up_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_up, + self.out_features_up * seq_len, + ) + + hidden = self.act_fn(gate_out) * up_out + down_out = QuantLinearTorchFunction.apply( + hidden, + down_qweight, + down_scales, + down_qzeros, + down_gidx if self.act_order else None, + self.bits, + self.group_size, + self.in_features_down, + self.out_features_down, + ) + down_out = down_out.reshape(seq_len, self.out_features_down) + final_hidden_states += down_out * curr_weight.unsqueeze(1) + + return final_hidden_states + + def moe_weights_as_activations(self, hidden_states, router_probs, router_weights): + + return QMOE.apply( + hidden_states, + router_weights, + self.fc1_experts_weights, + self.fc1_scales, + self.fc2_experts_weights, + self.fc2_scales, + self.fc3_experts_weights, + self.fc3_scales, + router_probs, + self.config.hidden_act, + self.group_size, + self.bits, + self.num_experts_per_tok, + ) + + @torch.no_grad() + def original_moe(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else 
sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + def moe_waa_unpack(self, hidden_states, topk_indices, topk_weights): + # GATHER - collect weights for selected experts + gate_proj_qweight = self.all_gate_qweight[topk_indices.flatten()] + gate_proj_scales = self.all_gate_scales[topk_indices.flatten()] + gate_proj_qzeros = self.all_gate_qzeros[topk_indices.flatten()] + + up_proj_qweight = self.all_up_qweight[topk_indices.flatten()] + up_proj_scales = self.all_up_scales[topk_indices.flatten()] + up_proj_qzeros = self.all_up_qzeros[topk_indices.flatten()] + + down_proj_qweight = self.all_down_qweight[topk_indices.flatten()] + down_proj_scales = self.all_down_scales[topk_indices.flatten()] + down_proj_qzeros = self.all_down_qzeros[topk_indices.flatten()] + + gate_proj_unpacked = CastToUInt4Func.apply(gate_proj_qweight) + gate_zeros_unpacked = CastToUInt4Func.apply(gate_proj_qzeros) + gate_proj_dq = DequantizeLinearFunc.apply( + gate_proj_unpacked, gate_proj_scales, gate_zeros_unpacked, self.group_size + ) + + up_proj_unpacked = CastToUInt4Func.apply(up_proj_qweight) + up_zeros_unpacked = CastToUInt4Func.apply(up_proj_qzeros) + up_proj_dq = DequantizeLinearFunc.apply(up_proj_unpacked, up_proj_scales, up_zeros_unpacked, self.group_size) + + down_proj_unpacked = CastToUInt4Func.apply(down_proj_qweight) + down_zeros_unpacked = CastToUInt4Func.apply(down_proj_qzeros) + down_proj_dq = DequantizeLinearFunc.apply( + down_proj_unpacked, down_proj_scales, down_zeros_unpacked, self.group_size + ) + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( - hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, self.config.hidden_size) + hidden_states.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous().view(-1, 1, 
self.in_features_gate) ) - gate_out = torch.bmm(expert_in, gate_proj) - up_out = torch.bmm(expert_in, up_proj) + gate_out = torch.bmm(expert_in, gate_proj_dq.transpose(1, 2)) + up_out = torch.bmm(expert_in, up_proj_dq.transpose(1, 2)) hidden = self.act_fn(gate_out) * up_out - expert_output = torch.bmm(hidden, down_proj) - experts_out = expert_output.view(seq_len, self.gate.top_k, self.config.hidden_size) - experts_out = experts_out * topk_weights.unsqueeze(-1) + down_out = torch.bmm(hidden, down_proj_dq.transpose(1, 2)) - final_hidden_states = torch.einsum("abc->ac", experts_out) + down_out = down_out.view(-1, self.gate.top_k, self.out_features_down) - return final_hidden_states.type(hidden_states.dtype) + down_out = down_out * topk_weights.unsqueeze(-1) + + return torch.einsum("abc-> ac", down_out) def forward(self, hidden_states): + print("Using new MoE forward with weights as activations") residuals = hidden_states orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) + topk_indices, topk_weights, router_probs, router_weights = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = self.moe_waa_unpack(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states class QEffPrefillOnlyDeepseekV3MoE(nn.Module): - def __qeff_init__( - self, - ): - for exp in self.experts: - gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) - down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) - - gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) - up_proj.weight = 
torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) - down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) - - setattr(exp, "gate_proj", gate_proj) - setattr(exp, "up_proj", up_proj) - setattr(exp, "down_proj", down_proj) + # def __qeff_init__( + # self, + # ): + # for exp in self.experts: + # gate_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + # up_proj = torch.nn.Linear(self.config.hidden_size, self.config.moe_intermediate_size, bias=False) + # down_proj = torch.nn.Linear(self.config.moe_intermediate_size, self.config.hidden_size, bias=False) + # + # gate_proj.weight = torch.nn.Parameter(exp.gate_proj.compressor.decompress_module(exp.gate_proj)) + # up_proj.weight = torch.nn.Parameter(exp.up_proj.compressor.decompress_module(exp.up_proj)) + # down_proj.weight = torch.nn.Parameter(exp.down_proj.compressor.decompress_module(exp.down_proj)) + # + # setattr(exp, "gate_proj", gate_proj) + # setattr(exp, "up_proj", up_proj) + # setattr(exp, "down_proj", down_proj) def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) @@ -828,6 +1117,7 @@ def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_ma current_hidden_states = expert_output * expert_mask[:, expert_idx].unsqueeze(-1) final_hidden_states += current_hidden_states + print("\n\ninside prefill only moe\n") return final_hidden_states.type(hidden_states.dtype) def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): @@ -862,7 +1152,7 @@ def forward(self, hidden_states): """ residuals = hidden_states orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) + topk_indices, topk_weights, _, _ = self.gate(hidden_states) # orig_out = self.orig_moe(hidden_states, 
topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4820715d7e..456e7fd4c5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -70,6 +70,7 @@ FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, Mxfp4GptOssExpertDequantizeTransform, + PackQuantizedInt4ToMatMulNBitsTransform, ) from QEfficient.utils import ( constants, @@ -2757,6 +2758,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, Mxfp4GptOssExpertDequantizeTransform, + PackQuantizedInt4ToMatMulNBitsTransform, CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5ff06e6443..9bd151f7f7 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -250,6 +250,7 @@ QEffDisentangledSelfAttention, ) from QEfficient.transformers.models.deepseek_v3.modeling_deepseek import ( + QEffDeepseekMoEGate, QEffDeepseekV3Attention, QEffDeepseekV3CustomRMSNormAIC, QEffDeepseekV3DecoderLayer, @@ -1034,8 +1035,11 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): }, "DeepseekV3MoE": { "forward": QEffDeepseekV3MoE.forward, - "moe": QEffDeepseekV3MoE.moe, + "moe_weights_as_activations": QEffDeepseekV3MoE.moe_weights_as_activations, + "moe_waa_unpack": QEffDeepseekV3MoE.moe_waa_unpack, + "original_moe": QEffDeepseekV3MoE.original_moe, "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, + "gate.forward": QEffDeepseekMoEGate.forward, }, "DeepseekV3Attention": { "forward": QEffDeepseekV3Attention.forward, diff --git a/QEfficient/transformers/quantizers/quant_transforms.py 
b/QEfficient/transformers/quantizers/quant_transforms.py index f97bfe998e..76a242e257 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- import torch +from compressed_tensors.compressors import PackedQuantizationCompressor +from compressed_tensors.linear.compressed_linear import CompressedLinear from torch import nn from transformers import AutoConfig from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts @@ -67,6 +69,43 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module): return new_module +class PackQuantizedInt4ToMatMulNBitsTransform(ModuleMutatorTransform): + """ + This transform is used to pack the quantized int4 weights into a format that can be used by the MatMulNBits kernel. + It is used for the ONNX export of the quantized model. + """ + + _match_class = CompressedLinear + + @classmethod + def mutate(cls, original_module, parent_module): + # add compressor.decompress to get the decompressed weight + # and then package into matmulnbit + assert isinstance(original_module.compressor, PackedQuantizationCompressor), ( + f"Only {PackedQuantizationCompressor} supported for now" + ) + fp_weight = original_module.compressor.decompress_module(original_module) + scales = original_module.weight_scale + # assuming symmetric quantization + quantization_args = original_module.quantization_scheme.weights + zeros = (torch.zeros_like(scales) + pow(2, (quantization_args.num_bits - 1))).to(torch.uint8) + g_idx = torch.arange(original_module.in_features // quantization_args.group_size).repeat_interleave( + quantization_args.group_size + ) + original_module.weight = torch.nn.Parameter(fp_weight) + assert quantization_args.type == "int", "uint is not tested yet" + new_module = QuantLinearORT( + quantization_args.num_bits, + quantization_args.group_size, + 
original_module.in_features, + original_module.out_features, + original_module.bias is not None, + ) + new_module.bias = original_module.bias if original_module.bias is not None else None + new_module.pack(original_module, scales, zeros, g_idx) + return new_module + + class GPTQToMatmulNbitsTransform(ModuleMutatorTransform): """ A transformation class that mutates a ``QuantLinearGPTQ`` module to a ``QuantLinearORT`` diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index f7ecc5b218..f20e350f64 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- + from dataclasses import dataclass from enum import Enum from typing import List @@ -12,6 +13,7 @@ import torch from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer +from transformers.utils import is_compressed_tensors_available from transformers.utils.quantization_config import CompressedTensorsConfig, QuantizationConfigMixin, QuantizationMethod from QEfficient.transformers.quantizers.quantizer_utils import blockwise_dequantize, get_keys_to_not_convert @@ -390,7 +392,55 @@ def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str class QEffCompressedTensorsConfig(CompressedTensorsConfig): - def __init__( + def handle_pack_quantized_init( + self, + config_groups=None, + format="dense", + quantization_status="initialized", + kv_cache_scheme=None, + global_compression_ratio=None, + ignore=None, + sparsity_config=None, + quant_method="compressed-tensors", + run_compressed: bool = True, + **kwargs, + ): + if is_compressed_tensors_available(): + from compressed_tensors.config import 
SparsityCompressionConfig + from compressed_tensors.quantization import QuantizationConfig + else: + raise ImportError( + "compressed_tensors is not installed and is required for compressed-tensors quantization. Please install it with `pip install compressed-tensors`." + ) + self.quantization_config = None + self.sparsity_config = None + + self.run_compressed = run_compressed + assert self.run_compressed, "pack-quantized needs to have run_compressed set to True" + + # parse from dict to load nested QuantizationScheme objects + if config_groups or kv_cache_scheme: + self.quantization_config = QuantizationConfig.model_validate( + { + "config_groups": config_groups, + "quant_method": quant_method, + "format": format, + "quantization_status": quantization_status, + "kv_cache_scheme": kv_cache_scheme, + "global_compression_ratio": global_compression_ratio, + "ignore": ignore, + **kwargs, + } + ) + + if sparsity_config: + self.sparsity_config = SparsityCompressionConfig.load_from_registry( + sparsity_config.get("format"), **sparsity_config + ) + + self.quant_method = QuantizationMethod.COMPRESSED_TENSORS + + def handle_fp8_init( self, config_groups=None, format="dense", @@ -480,7 +530,49 @@ def __init__( self.quant_method = QuantizationMethod.COMPRESSED_TENSORS + def __init__( + self, + config_groups=None, + format="dense", + quantization_status="initialized", + kv_cache_scheme=None, + global_compression_ratio=None, + ignore=None, + sparsity_config=None, + quant_method="compressed-tensors", + run_compressed: bool = None, + **kwargs, + ): + if format == "pack-quantized": + self.handle_pack_quantized_init( + config_groups=config_groups, + format=format, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=global_compression_ratio, + ignore=ignore, + sparsity_config=sparsity_config, + quant_method=quant_method, + run_compressed=True if run_compressed is None else run_compressed, + **kwargs, + ) + else: + 
self.handle_fp8_init( + config_groups=config_groups, + format=format, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=global_compression_ratio, + ignore=ignore, + sparsity_config=sparsity_config, + quant_method=quant_method, + run_compressed=False if run_compressed is None else run_compressed, + **kwargs, + ) + def to_dict(self): + if self.quantization_config.format == "pack-quantized": + return super().to_dict() return { "quantization_config": { "config_groups": self.config_groups, @@ -501,39 +593,56 @@ def to_dict(self): class QEffCompressedTensorsFP8Quantizer(CompressedTensorsHfQuantizer): requires_calibration = False - def __init__(self, quantization_config, **kwargs): - # TODO: check if more checks are required - if not isinstance(quantization_config, QEffCompressedTensorsConfig): - raise TypeError( - f"Only {QEffCompressedTensorsConfig} is supported for initialization got {type(quantization_config)}" - ) - self.run_compressed = quantization_config.run_compressed - self.quantization_config = quantization_config - - # -- Handle extra kwargs below -- - self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) - self.modules_to_not_convert = list( - set(self.modules_to_not_convert if self.modules_to_not_convert else []) - | set(self.quantization_config.ignore if self.quantization_config.ignore else []) + @staticmethod + def is_pack_quantized(quant_config): + return ( + hasattr(quant_config, "quantization_config") + and hasattr(quant_config.quantization_config, "format") + and quant_config.quantization_config.format == "pack-quantized" ) - self.pre_quantized = kwargs.pop("pre_quantized", True) - if not self.pre_quantized and self.requires_calibration: - raise ValueError( - f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." - f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. 
Make sure to " - f"pass `pre_quantized=True` while knowing what you are doing." + def __init__(self, quantization_config, **kwargs): + if self.is_pack_quantized(quantization_config): + super().__init__(quantization_config, **kwargs) + else: + if not isinstance(quantization_config, QEffCompressedTensorsConfig): + raise TypeError( + f"Only {QEffCompressedTensorsConfig} is supported for initialization got {type(quantization_config)}" + ) + self.run_compressed = quantization_config.run_compressed + self.quantization_config = quantization_config + + # -- Handle extra kwargs below -- + self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + self.modules_to_not_convert = list( + set(self.modules_to_not_convert if self.modules_to_not_convert else []) + | set(self.quantization_config.ignore if self.quantization_config.ignore else []) ) + self.pre_quantized = kwargs.pop("pre_quantized", True) + + if not self.pre_quantized and self.requires_calibration: + raise ValueError( + f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." + f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " + f"pass `pre_quantized=True` while knowing what you are doing." 
+ ) def validate_environment(self, *args, **kwargs): + if self.is_pack_quantized(self.quantization_config): + return super().validate_environment(*args, **kwargs) return True def update_torch_dtype(self, torch_dtype): + if self.is_pack_quantized(self.quantization_config): + return super().update_torch_dtype(torch_dtype) if torch_dtype not in [None, torch.float32]: logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") return None def _process_model_before_weight_loading(self, model, **kwargs): + if self.is_pack_quantized(self.quantization_config): + super()._process_model_before_weight_loading(model, **kwargs) + return if self.quantization_config.targets != ["Linear"]: raise NotImplementedError( f"Only Linear layer with FP8 quantization are supported got targets = {self.quantization_config.targets}" @@ -561,10 +670,17 @@ def replace_linear_with_fp8_dequant_layer(module): replace_linear_with_fp8_dequant_layer(model) def _process_model_after_weight_loading(self, model, **kwargs): + if self.is_pack_quantized(self.quantization_config): + super()._process_model_after_weight_loading(model, **kwargs) + return pass def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]: + if self.is_pack_quantized(self.quantization_config): + return super().update_missing_keys_after_loading(model, missing_keys=missing_keys, prefix=prefix) return missing_keys def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str = None) -> List[str]: + if self.is_pack_quantized(self.quantization_config): + return super().update_unexpected_keys(model, unexpected_keys=unexpected_keys, prefix=prefix) return unexpected_keys From 7bbeed15da9dcab454e0cad92a3fe4e36a3fd7f2 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 28 Apr 2026 14:14:03 +0530 Subject: [PATCH 2/5] add prefill changes Signed-off-by: Mamta Singh --- .../blocking/blocked_attention_forwards.py | 30 ++- 
QEfficient/customop/ctx_scatter_gather.py | 7 +- .../models/deepseek_v3/modeling_deepseek.py | 181 +++++++++++++++--- .../transformers/models/pytorch_transforms.py | 2 +- 4 files changed, 184 insertions(+), 36 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 6b03286503..c4f64db090 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -841,7 +841,9 @@ def blocked_kv_mla_attention_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize result tensor batch_size, num_heads, seq_len, _ = query.shape - output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype) + output = torch.zeros( + batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype + ) if hasattr(module, "config"): mask_dtype = module.config.torch_dtype @@ -897,17 +899,23 @@ def blocked_kv_mla_attention_forward( k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank) + repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.kv_lora_rank + ) compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) - repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + repeated_k_pe_block = k_pe_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + ) k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) if absorption: krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1) k_heads, q_heads = krope_nope.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_k = krope_nope[:, 
0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank) + repeated_k = krope_nope[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank + ) krope_nope = torch.cat((krope_nope, repeated_k), dim=1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] @@ -924,13 +932,17 @@ def blocked_kv_mla_attention_forward( else: k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1] num_heads_to_repeat = q_heads - k_heads - repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank) + repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.kv_lora_rank + ) compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1) knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) - - repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim) + + repeated_k_pe_block = k_pe_block[:, 0, :, :].expand( + batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + ) k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1) - + krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) @@ -944,7 +956,6 @@ def blocked_kv_mla_attention_forward( skip_future, ) - attn_output = torch.matmul(output, per_head_v_up) attn_output = attn_output.transpose(1, 2).contiguous() attn_weights = None @@ -995,6 +1006,7 @@ def blocked_h_mla_attention_forward( h_output_blocks = [] h_attn_blocks = [] + # Process each head block independently for head_block_idx in range(num_head_blocks): h_start = head_block_idx * head_block_size 
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 59bfe6af03..4f46791af8 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates # Create indices batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + + # keep index tensor types aligned for backend that require exact dtype match + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) indices = ops.Concat(batch_idx, ctx_idx, axis=2) @@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() batch_idx = torch.arange(data.shape[0]).view(-1, 1) - ctx_idx = position_ids + ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) data[batch_idx, ctx_idx] = updates return data @@ -103,6 +107,7 @@ class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 6c3bd0c083..201cc06600 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -20,6 +20,7 @@ generic_blocked_attention_interface, generic_blocked_mla_attention_interface, ) +from QEfficient.customop.ctx_scatter_gather import 
CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.customop.matmulnbits import QMOE, QuantLinearTorchFunction from QEfficient.customop.quantization_ops import CastToUInt4Func, DequantizeLinearFunc from QEfficient.customop.rms_norm import CustomRMSNormFunc @@ -742,6 +743,42 @@ def forward( ) +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) + + +def _ctx_scatter_gather_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + act_fn, + T: int, +) -> torch.Tensor: + """Packed-prefix expert helper for NSP-blocked dispatch.""" + batch_size, hidden_size = T2Ei.shape[0], x.shape[1] + scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) + invalid_mask = ~T2Ei + INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + + x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) + x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + + gate_prime = x_prime @ W_g + up_prime = x_prime @ W_u + down_prime = (up_prime * act_fn(gate_prime)) @ W_d + + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) + down_prime = torch.where((row_range < valid_rows).unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + + gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) + delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) + return delta_out + + class QEffDeepseekMoEGate(nn.Module): def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape @@ -1106,6 +1143,47 @@ class QEffPrefillOnlyDeepseekV3MoE(nn.Module): # setattr(exp, "up_proj", up_proj) # setattr(exp, "down_proj", down_proj) + def 
__qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] + self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] + self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :].unsqueeze(-1) + T2Ei = routing_weight.squeeze(-1) > 0 + delta = _ctx_scatter_gather_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + act_fn=self.experts[0].act_fn, + T=T, + ) + expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) + def moe(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, expert_mask: torch.Tensor, num_experts: int): final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) for expert_idx in range(num_experts): @@ -1146,31 +1224,84 @@ def orig_moe(self, 
hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk # and all expert are "local" meaning we shard but we don't gather return final_hidden_states.type(hidden_states.dtype) - def forward(self, hidden_states): - """ - Forward pass of MoE block. - """ - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights, _, _ = self.gate(hidden_states) - # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) - mask.scatter_(1, topk_indices, topk_weights) - if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: - hidden_states = self.moe_blocked_weights_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: - hidden_states = self.moe_blocked_forward( - hidden_states, topk_weights, mask, self.config.n_routed_experts - ).view(*orig_shape) - else: - hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) - - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states + def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) # [T, E] + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
+ top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + routing_weights = masked_logits + expert_out = x.new_zeros((T, H)) + for e in range(self.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight + return expert_out.view(B, S, H), router_logits + + # def forward(self, hidden_states): + # """ + # Forward pass of MoE block. + # """ + # residuals = hidden_states + # orig_shape = hidden_states.shape + # topk_indices, topk_weights, _, _ = self.gate(hidden_states) + # # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + # + # hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + # mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) + # mask.scatter_(1, topk_indices, topk_weights) + # + # if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: + # hidden_states = self.moe_blocked_weights_forward( + # hidden_states, topk_weights, mask, self.config.n_routed_experts + # ).view(*orig_shape) + # elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: + # hidden_states = self.moe_blocked_forward( + # hidden_states, topk_weights, mask, self.config.n_routed_experts + # ).view(*orig_shape) + # else: + # hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) + # + # hidden_states = hidden_states + self.shared_experts(residuals) + # return hidden_states + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = 
self.gate(x) + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + expert_out = x.new_zeros((T, H)) + for e in range(self.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight + return expert_out.view(B, S, H), router_logits class QEffDeepseekV3DecoderLayer(nn.Module): diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 9bd151f7f7..60c92614ae 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1073,7 +1073,7 @@ class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransfo _match_string_replace_method = { "DeepseekV3MoE": { "forward": QEffDeepseekV3MoE.forward, - "moe": QEffDeepseekV3MoE.moe, + "moe": QEffDeepseekV3MoE.moe_waa_unpack, "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, }, } From 9201589440d1004cf1e763d98c2f3f504d00a5d2 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 28 Apr 2026 15:17:10 +0530 Subject: [PATCH 3/5] fix modeling Signed-off-by: Mamta Singh --- .../models/deepseek_v3/modeling_deepseek.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py 
b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 201cc06600..b87e0e34cc 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1148,10 +1148,10 @@ def __qeff_init__(self): self.up_proj_w = [] self.down_proj_w = [] with torch.no_grad(): - for e in range(self.num_experts): - self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) - self.up_proj_w.append(self.experts[e].up_proj.weight.T) - self.down_proj_w.append(self.experts[e].down_proj.weight.T) + for e in range(len(self.experts)): + self.gate_proj_w.append(self.experts[e].gate_proj.qweight.T) + self.up_proj_w.append(self.experts[e].up_proj.qweight.T) + self.down_proj_w.append(self.experts[e].down_proj.qweight.T) self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] @@ -1159,11 +1159,11 @@ def __qeff_init__(self): def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP - if self.num_experts % num_nsp != 0: + if len(self.experts) % num_nsp != 0: raise ValueError( - f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + f"num_experts ({len(self.experts)}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" ) - local_experts = self.num_experts // num_nsp + local_experts = len(self.experts) // num_nsp rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() @@ -1238,10 +1238,10 @@ def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch masked_logits.scatter_(1, top_i, top_w) routing_weights = 
masked_logits expert_out = x.new_zeros((T, H)) - for e in range(self.num_experts): + for e in range(len(self.experts)): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T - W_d = self.experts[e].down_proj.weight.T + W_g, W_u = self.experts[e].gate_proj.qweight.T, self.experts[e].up_proj.qweight.T + W_d = self.experts[e].down_proj.qweight.T gate = x @ W_g up = x @ W_u down = (up * self.experts[e].act_fn(gate)) @ W_d @@ -1288,15 +1288,15 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights = torch.zeros_like(router_logits) routing_weights.scatter_(1, top_i, top_w) - if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + if len(self.experts) % EXPERT_BLOCKING_NUM_NSP == 0: expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) return expert_out.view(B, S, H), router_logits expert_out = x.new_zeros((T, H)) - for e in range(self.num_experts): + for e in range(len(self.experts)): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T - W_d = self.experts[e].down_proj.weight.T + W_g, W_u = self.experts[e].gate_proj.qweight.T, self.experts[e].up_proj.qweight.T + W_d = self.experts[e].down_proj.qweight.T gate = x @ W_g up = x @ W_u down = (up * self.experts[e].act_fn(gate)) @ W_d From e57dbc0a519e331dcfde775668e712668ab8e014 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 28 Apr 2026 15:36:23 +0530 Subject: [PATCH 4/5] prefill fix Signed-off-by: Mamta Singh --- .../transformers/models/modeling_auto.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a106c6d0d9..69431502df 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3100,14 
+3100,16 @@ def export( ) self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) self.hash_params.pop("retain_full_kv", None) - seq_len = self.get_seq_len_and_handle_specialized_prefill_model( - prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking - ) - kv_cache_shape[2] = ( - seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) - if enable_chunking - else seq_len - ) + if "DeepseekV3ForCausalLM" not in (getattr(self.model.config, "architectures", None) or []): + seq_len = self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + ) + kv_cache_shape[2] = ( + seq_len + + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + if enable_chunking + else seq_len + ) else: self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) self.hash_params.pop("prefill_only", None) From f499be872cb09f347d887f362d87125ca44a3bc6 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 29 Apr 2026 15:54:03 +0530 Subject: [PATCH 5/5] pushing local changes not working yet Signed-off-by: Onkar Chougule --- .../models/deepseek_v3/modeling_deepseek.py | 86 +++++++------------ .../transformers/models/modeling_auto.py | 17 ++-- .../transformers/models/pytorch_transforms.py | 2 +- 3 files changed, 39 insertions(+), 66 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index b87e0e34cc..fd7cdf1bf5 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -1143,18 +1143,19 @@ class QEffPrefillOnlyDeepseekV3MoE(nn.Module): # setattr(exp, "up_proj", up_proj) # setattr(exp, "down_proj", down_proj) - def __qeff_init__(self): - self.gate_proj_w = [] - self.up_proj_w = [] - self.down_proj_w = 
[] - with torch.no_grad(): - for e in range(len(self.experts)): - self.gate_proj_w.append(self.experts[e].gate_proj.qweight.T) - self.up_proj_w.append(self.experts[e].up_proj.qweight.T) - self.down_proj_w.append(self.experts[e].down_proj.qweight.T) - self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] - self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] - self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] + # def __qeff_init__(self): + # import ipdb; ipdb.set_trace() + # self.gate_proj_w = [] + # self.up_proj_w = [] + # self.down_proj_w = [] + # with torch.no_grad(): + # for e in range(len(self.experts)): + # self.gate_proj_w.append(self.experts[e].gate_proj.qweight.T) + # self.up_proj_w.append(self.experts[e].up_proj.qweight.T) + # self.down_proj_w.append(self.experts[e].down_proj.qweight.T) + # self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] + # self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] + # self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: T, H = x.shape @@ -1248,60 +1249,31 @@ def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch expert_out += down * routing_weight return expert_out.view(B, S, H), router_logits - # def forward(self, hidden_states): - # """ - # Forward pass of MoE block. 
- # """ - # residuals = hidden_states - # orig_shape = hidden_states.shape - # topk_indices, topk_weights, _, _ = self.gate(hidden_states) - # # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - # - # hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - # mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) - # mask.scatter_(1, topk_indices, topk_weights) - # - # if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: - # hidden_states = self.moe_blocked_weights_forward( - # hidden_states, topk_weights, mask, self.config.n_routed_experts - # ).view(*orig_shape) - # elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: - # hidden_states = self.moe_blocked_forward( - # hidden_states, topk_weights, mask, self.config.n_routed_experts - # ).view(*orig_shape) - # else: - # hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) - # - # hidden_states = hidden_states + self.shared_experts(residuals) - # return hidden_states def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + topk_idx, topk_weight, router_probs, router_weights = self.gate(hidden_states) B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) - router_logits = self.gate(x) - prob = F.softmax(router_logits, -1, dtype=torch.float) - top_w, top_i = torch.topk(prob, self.top_k, -1) - if self.norm_topk_prob: - top_w /= top_w.sum(-1, keepdim=True) - top_w = top_w.to(hidden_states.dtype) - routing_weights = torch.zeros_like(router_logits) - routing_weights.scatter_(1, top_i, top_w) + + routing_weights = torch.zeros(T, self.config.n_routed_experts) + routing_weights.scatter_(1, topk_idx, topk_weight) if len(self.experts) % EXPERT_BLOCKING_NUM_NSP == 0: expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - return expert_out.view(B, S, H), router_logits + return 
expert_out.view(B, S, H) - expert_out = x.new_zeros((T, H)) - for e in range(len(self.experts)): - routing_weight = routing_weights[:, e].unsqueeze(-1) - W_g, W_u = self.experts[e].gate_proj.qweight.T, self.experts[e].up_proj.qweight.T - W_d = self.experts[e].down_proj.qweight.T - gate = x @ W_g - up = x @ W_u - down = (up * self.experts[e].act_fn(gate)) @ W_d - expert_out += down * routing_weight - return expert_out.view(B, S, H), router_logits + final_hidden_states = x.new_zeros((T, H)) + for expert_idx in range(self.n_routed_experts): + expert = self.experts[expert_idx] + gate_out = expert.gate_proj(hidden_states) + up_out = expert.up_proj(hidden_states) + hidden = expert.act_fn(gate_out) * up_out + expert_output = expert.down_proj(hidden) + current_hidden_states = expert_output * routing_weights[:, expert_idx].unsqueeze(-1) + final_hidden_states += current_hidden_states + + return final_hidden_states.view(B, S, H) class QEffDeepseekV3DecoderLayer(nn.Module): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 69431502df..ed9528379f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3100,13 +3100,13 @@ def export( ) self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) self.hash_params.pop("retain_full_kv", None) - if "DeepseekV3ForCausalLM" not in (getattr(self.model.config, "architectures", None) or []): + + if self.model.config.model_type == "gpt_oss": seq_len = self.get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking ) kv_cache_shape[2] = ( - seq_len - + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) if enable_chunking else seq_len ) @@ -3117,11 +3117,12 @@ def export( 
self.hash_params.pop("NUM_FFN_BLOCKS", None) self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) - if kwargs.get("retain_full_kv", False): - kv_cache_shape[2] = seq_len + ( - self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 - ) - self.hash_params["retain_full_kv"] = True + if self.model.config.model_type == "gpt_oss": + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + ( + self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + ) + self.hash_params["retain_full_kv"] = True example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 60c92614ae..e5199d23eb 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1063,7 +1063,7 @@ class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): "DeepseekV3MoE": { "forward": QEffPrefillOnlyDeepseekV3MoE.forward, "moe": QEffPrefillOnlyDeepseekV3MoE.moe, - "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + # "__qeff_init__": QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, }, }