Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
CtxScatterFuncCB,
CtxScatterFuncCB3D,
)
from QEfficient.customop.quantization_ops import CastToUInt4, CastToUInt4Func
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.utils.constants import FILE_CHUNK_SIZE_DEFAULT, ONNX_EXPORT_OPSET, SIZE_THRESHOLD_DEFAULT

Expand Down Expand Up @@ -100,6 +101,7 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CastToUInt4": (CastToUInt4Func, CastToUInt4),
}

@classmethod
Expand Down
13 changes: 11 additions & 2 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from types import MethodType
from typing import Callable, Dict, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type

from torch import nn

Expand Down Expand Up @@ -97,6 +98,7 @@ class ModuleMutatorTransform(PytorchTransform):
"""

_match_class: nn.Module
_match_string: Optional[str] = None

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
Expand Down Expand Up @@ -135,7 +137,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
parts = orig_method_name.split(".")
if len(parts) > 1:
target = module
for part in parts[:-1]:
target = getattr(target, part)
setattr(target, parts[-1], MethodType(mapped_method, target))
else:
setattr(module, orig_method_name, MethodType(mapped_method, module))

if hasattr(module, "__qeff_init__"):
module.__qeff_init__()
Expand Down
30 changes: 21 additions & 9 deletions QEfficient/blocking/blocked_attention_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,9 @@ def blocked_kv_mla_attention_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Initialize result tensor
batch_size, num_heads, seq_len, _ = query.shape
output = torch.zeros(batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype)
output = torch.zeros(
batch_size, num_heads, seq_len, module.config.kv_lora_rank, device=query.device, dtype=query.dtype
)

if hasattr(module, "config"):
mask_dtype = module.config.torch_dtype
Expand Down Expand Up @@ -897,17 +899,23 @@ def blocked_kv_mla_attention_forward(

k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1]
num_heads_to_repeat = q_heads - k_heads
repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank)
repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand(
batch_size, num_heads_to_repeat, -1, module.kv_lora_rank
)
compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1)

repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim)
repeated_k_pe_block = k_pe_block[:, 0, :, :].expand(
batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim
)
k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1)

if absorption:
krope_nope = torch.cat((compressed_kv_block, k_pe_block), dim=-1)
k_heads, q_heads = krope_nope.shape[1], query.shape[1]
num_heads_to_repeat = q_heads - k_heads
repeated_k = krope_nope[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank)
repeated_k = krope_nope[:, 0, :, :].expand(
batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim + module.kv_lora_rank
)
krope_nope = torch.cat((krope_nope, repeated_k), dim=1)
attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling
# [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size]
Expand All @@ -924,13 +932,17 @@ def blocked_kv_mla_attention_forward(
else:
k_heads, q_heads = compressed_kv_block.shape[1], query.shape[1]
num_heads_to_repeat = q_heads - k_heads
repeated_ckv_block = compressed_kv_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.kv_lora_rank)
repeated_ckv_block = compressed_kv_block[:, 0, :, :].expand(
batch_size, num_heads_to_repeat, -1, module.kv_lora_rank
)
compressed_kv_block = torch.cat((compressed_kv_block, repeated_ckv_block), dim=1)
knope = torch.matmul(compressed_kv_block, per_head_k_up_normal)

repeated_k_pe_block = k_pe_block[:, 0,:,:].expand(batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim)

repeated_k_pe_block = k_pe_block[:, 0, :, :].expand(
batch_size, num_heads_to_repeat, -1, module.qk_rope_head_dim
)
k_pe_block = torch.cat((k_pe_block, repeated_k_pe_block), dim=1)

krope_nope = torch.cat((knope, k_pe_block.expand(-1, num_heads, -1, -1)), dim=-1)
attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling
attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)
Expand All @@ -944,7 +956,6 @@ def blocked_kv_mla_attention_forward(
skip_future,
)


attn_output = torch.matmul(output, per_head_v_up)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_weights = None
Expand Down Expand Up @@ -995,6 +1006,7 @@ def blocked_h_mla_attention_forward(

h_output_blocks = []
h_attn_blocks = []

# Process each head block independently
for head_block_idx in range(num_head_blocks):
h_start = head_block_idx * head_block_size
Expand Down
7 changes: 6 additions & 1 deletion QEfficient/customop/ctx_scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape)

# Keep index tensor dtypes aligned for backends that require an exact dtype match
batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype)
ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape)
indices = ops.Concat(batch_idx, ctx_idx, axis=2)

Expand All @@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates
class CtxScatterFunc3D(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
data = data.clone()
batch_idx = torch.arange(data.shape[0]).view(-1, 1)
ctx_idx = position_ids
ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids)
data[batch_idx, ctx_idx] = updates
return data

Expand All @@ -103,6 +107,7 @@ class CtxGatherFunc3D(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = torch.arange(data.shape[0]).view(-1, 1)
ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
return data[batch_indices, ctx_indices]

@staticmethod
Expand Down
88 changes: 81 additions & 7 deletions QEfficient/customop/matmulnbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
class QuantLinearTorchFunction(torch.autograd.Function):
@staticmethod
def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
if qself_qzeros is None:
input_tuple = (x, qself_qweight, qself_scales)
else:
input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
input_tuple += (g_idx,) if g_idx is not None else ()
return g.op(
"com.microsoft::MatMulNBits",
Expand All @@ -28,7 +31,10 @@ def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group

@staticmethod
def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
if qself_qzeros is None:
    # No zero points were supplied: fall back to the midpoint of the unsigned
    # range, 2**(bits - 1) (i.e. 8 for 4-bit weights). The previous expression
    # used `^`, which is bitwise XOR in Python (2 ^ 3 == 1), not exponentiation.
    qself_qzeros = 2 ** (bits - 1)
if torch.onnx.is_in_onnx_export():
# For faster export
return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
fp_weight = dequantize_blockwise_bits(
qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
Expand All @@ -40,8 +46,7 @@ def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, grou
def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
if bits != 4:
raise ValueError("Only bits=4 is supported for executing quantized model")
if group_size != 128:
raise ValueError("Only group_size=128 is supported for executing quantized model")

expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
Expand Down Expand Up @@ -88,20 +93,20 @@ def __init__(self, bits, group_size, in_features, out_features, bias):
q_rows = in_features // self.group_size
self.register_buffer(
"qweight",
torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
torch.empty((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
)
self.register_buffer(
"qzeros",
torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
torch.empty((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
)
self.register_buffer(
"scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
"scales", torch.empty((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
)
self.register_buffer(
"g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
self.register_buffer("bias", torch.empty((out_features), dtype=torch.float16))
else:
self.bias = None

Expand Down Expand Up @@ -180,3 +185,72 @@ def forward(self, inputs):
)
out = out + self.bias if self.bias is not None else out
return out


class QMOE(torch.autograd.Function):
    """Autograd function that is exported as a single ``com.microsoft::QMoE`` node.

    ``symbolic`` emits the ONNX node, while ``forward`` is a dummy placeholder
    that performs no expert computation.
    """

    @staticmethod
    def symbolic(
        g,
        x,
        router_weights,
        fc1_experts_weights,
        fc1_scales,
        fc2_experts_weights,
        fc2_scales,
        fc3_experts_weights,
        fc3_scales,
        router_probs,
        activation_type,
        block_size,
        expert_weight_bits,
        k,
    ):
        """Emit the ``com.microsoft::QMoE`` node and return its single output.

        The node output is returned unchanged; no ``router_probs`` scaling is
        added to the graph here.
        """
        # Node inputs in the order they are passed to the op. Note that
        # ``router_probs`` is third here even though it follows the fc* tensors
        # in this function's own argument list.
        node_inputs = (
            x,
            router_weights,
            router_probs,
            fc1_experts_weights,
            fc1_scales,
            fc2_experts_weights,
            fc2_scales,
            fc3_experts_weights,
            fc3_scales,
        )
        # Attribute-name suffixes encode the attribute type: ``_s`` for string,
        # ``_i`` for integer.
        node_attrs = {
            "activation_type_s": activation_type,
            "block_size_i": block_size,
            "expert_weight_bits_i": expert_weight_bits,
            "k_i": k,
        }
        return g.op("com.microsoft::QMoE", *node_inputs, outputs=1, **node_attrs)

    @staticmethod
    def forward(
        ctx,
        x,
        router_weights,
        fc1_experts_weights,
        fc1_scales,
        fc2_experts_weights,
        fc2_scales,
        fc3_experts_weights,
        fc3_scales,
        router_probs,
        activation_type,
        block_size,
        expert_weight_bits,
        k,
    ):
        """Dummy forward: return a zero tensor scaled by the mean of ``router_probs``.

        Only ``x`` and ``router_probs`` are read; the remaining arguments are
        accepted so the signature lines up with ``symbolic`` but are unused.
        """
        # Stand-in for the QMoE output: zeros shaped like ``x``.
        placeholder_out = torch.zeros_like(x)
        # Mean over the last axis, kept as a size-1 dim so it broadcasts.
        mean_router_prob = torch.mean(router_probs, dim=-1, keepdim=True)
        return placeholder_out * mean_router_prob
Loading