Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
CtxScatter3D,
CtxScatterFunc,
CtxScatterFunc3D,
CtxScatter3DInt,
CtxScatterFunc3DInt,
CtxScatterFunc3DGeneralized,
CtxGatherFunc3DGeneralized,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherBlockedKVCB,
Expand All @@ -39,6 +43,7 @@
CtxScatterFuncCB,
CtxScatterFuncCB3D,
)
from QEfficient.customop.quantization_ops import CastToUInt4, CastToUInt4Func
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.utils.constants import FILE_CHUNK_SIZE_DEFAULT, ONNX_EXPORT_OPSET, SIZE_THRESHOLD_DEFAULT

Expand Down Expand Up @@ -100,6 +105,10 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CastToUInt4": (CastToUInt4Func, CastToUInt4),
"CtxScatterFunc3DInt": (CtxScatterFunc3DInt, CtxScatter3DInt),
"CtxScatterFunc3DGeneralized":(CtxScatterFunc3DGeneralized, CtxScatter3D),
"CtxGatherFunc3DGeneralized": (CtxGatherFunc3DGeneralized, CtxGather3D),
}

@classmethod
Expand Down
12 changes: 10 additions & 2 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# ----------------------------------------------------------------------------
from types import MethodType
from typing import Callable, Dict, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type

from torch import nn

Expand Down Expand Up @@ -97,6 +97,7 @@ class ModuleMutatorTransform(PytorchTransform):
"""

_match_class: nn.Module
_match_string: Optional[str] = None

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
Expand Down Expand Up @@ -135,7 +136,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
parts = orig_method_name.split(".")
if len(parts) > 1:
target = module
for part in parts[:-1]:
target = getattr(target, part)
setattr(target, parts[-1], MethodType(mapped_method, target))
else:
setattr(module, orig_method_name, MethodType(mapped_method, module))

if hasattr(module, "__qeff_init__"):
module.__qeff_init__()
Expand Down
5 changes: 5 additions & 0 deletions QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
blocked_hqkv_attention_forward,
blocked_kv_attention_forward,
blocked_kv_mla_attention_forward,
blocked_kv_par_mla_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
)
Expand All @@ -34,6 +35,7 @@ class BlockingMode(str, Enum):
QKV = "qkv"
HQKV = "hqkv"
BHQKV = "bhqkv"
PAR = "par"


@dataclass
Expand All @@ -44,6 +46,7 @@ class AttentionBlockingConfig:
head_block_size: Optional[int] = None
skip_kv: Optional[bool] = False
num_batch_blocks: Optional[int] = None
par_num_split: Optional[int] = None


def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
Expand All @@ -62,6 +65,7 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
# Dispatch table for MLA attention: maps a BlockingMode member to the blocked
# forward function used for that mode. Only the modes listed here have an MLA
# implementation registered in this table.
_STRATEGIES_MLA: Dict[BlockingMode, Callable] = {
BlockingMode.KV: blocked_kv_mla_attention_forward,
BlockingMode.H: blocked_h_mla_attention_forward,
# "par" mode: KV blocks that are additionally split into parallel chunks.
BlockingMode.PAR: blocked_kv_par_mla_attention_forward
}


Expand Down Expand Up @@ -224,6 +228,7 @@ def generic_blocked_mla_attention_interface(
num_q_blocks=blocking_config.num_q_blocks,
head_block_size=blocking_config.head_block_size,
num_batch_blocks=blocking_config.num_batch_blocks,
par_num_split=blocking_config.par_num_split,
score_mod=score_mod,
position_bias=position_bias,
sinks=sinks,
Expand Down
169 changes: 169 additions & 0 deletions QEfficient/blocking/blocked_attention_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,12 @@ def blocked_kv_mla_attention_forward(
) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512]
else:
knope = torch.matmul(compressed_kv_block, per_head_k_up_normal)
if k_heads == 1:
k_pe_block = (
k_pe_block.unsqueeze(1)
.expand(-1, num_heads, -1, -1, -1)
.reshape(batch_size, num_heads, -1, module.config.qk_rope_head_dim)
)
krope_nope = torch.cat((knope, k_pe_block), dim=-1)
attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling
attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)
Expand All @@ -951,6 +957,169 @@ def blocked_kv_mla_attention_forward(
return attn_output, attn_weights


def blocked_kv_par_mla_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    per_head_v_up: torch.Tensor,
    per_head_k_up_normal: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    par_num_split: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    compressed_kvs,
    mla_absorption: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Blocked MLA attention where every KV block is further cut into parallel splits.

    The context is walked in ``num_kv_blocks`` blocks. Each block is reshaped
    into ``split`` equally sized chunks (zero-padded when the block length is
    not divisible), attention is evaluated for all chunks with one batched
    matmul, and the per-chunk (max, sum, output) triples are merged at the end
    with a numerically stable softmax reduction: first across blocks, then
    across splits.

    Args:
        module: Attention module; ``module.config.kv_lora_rank`` is always
            read, ``module.config.qk_rope_head_dim`` and the optional
            ``module.num_key_value_heads`` depending on the path.
        query: 4-D tensor unpacked as ``(batch, num_heads, q_len, query_dim)``.
        per_head_v_up: Right operand of the final matmul applied to the
            attention output (value up-projection).
        per_head_k_up_normal: Right operand used to expand the compressed KV
            into per-head keys on the non-absorption path.
        attention_mask: Unused; the causal mask is rebuilt from
            ``position_ids``.
        scaling: Factor applied to the raw attention scores.
        num_kv_blocks: Number of blocks the context length is divided into.
        par_num_split: Number of parallel chunks per block; normalised with
            ``_normalize_int`` and clamped to at least 1.
        cache_kwargs: Must contain ``"position_ids"`` (2-D, batch x positions);
            also forwarded to the cache read helpers.
        layer_idx: Layer whose cache entries are read.
        compressed_kvs: Cache object exposing ``layers[layer_idx].ckv``,
            ``read_only_blocked_ckv`` and ``read_only_blocked_k_pe``.
        mla_absorption: Optional dict; a truthy ``"absorption"`` entry selects
            the absorbed path in which query heads are folded onto the KV heads.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, None)``. ``attn_output`` is the attention result after
        the value up-projection with dims 1 and 2 transposed; attention weights
        are never returned.

    Raises:
        ValueError: If ``cache_kwargs`` carries no ``"position_ids"``.
    """
    batch_size, num_heads, q_len, query_dim = query.shape
    kv_lora_rank = module.config.kv_lora_rank

    absorption = mla_absorption.get("absorption", False) if mla_absorption is not None else False

    position_ids = cache_kwargs.get("position_ids")
    if position_ids is None:
        raise ValueError("cache_kwargs must provide 'position_ids' for parallel blocked MLA attention")

    # Loop-invariant head bookkeeping for the non-absorption expansion.
    original_kv_heads = getattr(module, "num_key_value_heads", 1)
    num_repeats = num_heads // original_kv_heads

    if absorption:
        num_key_value_heads = original_kv_heads
        num_heads_per_kv = num_heads // num_key_value_heads
    else:
        num_key_value_heads = num_heads
        num_heads_per_kv = 1

    split = max(1, _normalize_int(par_num_split))

    # Fold the query heads that share a KV head into the row axis, then
    # broadcast the result over the split axis.
    q_fold = query.reshape(batch_size, num_key_value_heads, q_len * num_heads_per_kv, query_dim)
    q_5d = q_fold.unsqueeze(2).expand(
        batch_size, num_key_value_heads, split, q_len * num_heads_per_kv, query_dim
    )

    # Folding orders the rows as (head_within_kv, q); the per-row positions
    # must follow the same layout, otherwise the causal mask cannot broadcast
    # against the scores once q_len > 1.
    q_positions = position_ids
    if num_heads_per_kv > 1 and q_len > 1 and position_ids.shape[-1] == q_len:
        q_positions = position_ids.repeat(1, num_heads_per_kv)

    ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2]
    kv_block_size = -(-ctx_len // num_kv_blocks)  # ceil division
    split_block_size = -(-kv_block_size // split)
    # In-block offset of every (split, slot) pair for a full-sized block.
    kv_offsets = (
        torch.arange(split, device=query.device)[:, None] * split_block_size
        + torch.arange(split_block_size, device=query.device)[None, :]
    ).view(1, 1, split, 1, split_block_size)

    masked_tensor = torch.tensor(-3.0e4, dtype=query.dtype, device=query.device)
    current_position = position_ids.max(dim=-1).values
    skip_kv = True

    max_blocks = []
    sum_blocks = []
    output_blocks = []

    for block_idx in range(num_kv_blocks):
        start_index = block_idx * kv_block_size
        if block_idx == num_kv_blocks - 1:
            kv_len_block = ctx_len - start_index
        else:
            kv_len_block = kv_block_size
        end_index = start_index + kv_len_block

        skip_future = None
        if skip_kv:
            # True when the whole block lies beyond every current position.
            skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()
            if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                # Eager mode can stop early; traced graphs zero the block below.
                if skip_future.item():
                    break

        compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs)
        k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, layer_idx, cache_kwargs)

        if absorption:
            key_block = torch.cat((compressed_kv_block, k_pe_block), dim=-1)
            value_block = compressed_kv_block
        else:
            compressed_kv_block = (
                compressed_kv_block.unsqueeze(2)
                .expand(-1, original_kv_heads, num_repeats, -1, -1)
                .reshape(batch_size, num_heads, kv_len_block, kv_lora_rank)
            )
            k_pe_block = (
                k_pe_block.unsqueeze(2)
                .expand(-1, original_kv_heads, num_repeats, -1, -1)
                .reshape(batch_size, num_heads, kv_len_block, module.config.qk_rope_head_dim)
            )
            k_nope_block = torch.matmul(compressed_kv_block, per_head_k_up_normal)
            key_block = torch.cat((k_nope_block, k_pe_block), dim=-1)
            value_block = compressed_kv_block

        # Pad the KV axis so it divides evenly into `split` chunks.
        pad = 0
        padded_kv_len = kv_len_block
        if padded_kv_len % split != 0:
            pad = split - (padded_kv_len % split)
            key_block = F.pad(key_block, (0, 0, 0, pad))
            value_block = F.pad(value_block, (0, 0, 0, pad))
            padded_kv_len += pad

        per_split_kv_len = padded_kv_len // split
        key_5d = key_block.view(batch_size, num_key_value_heads, split, per_split_kv_len, query_dim)
        value_5d = value_block.view(batch_size, num_key_value_heads, split, per_split_kv_len, kv_lora_rank)

        attn_weights_block = torch.matmul(q_5d, key_5d.transpose(-1, -2)) * scaling

        if per_split_kv_len == split_block_size:
            offsets = kv_offsets
        else:
            # A shorter (last) block has a different chunk stride, so chunk s
            # starts at s * per_split_kv_len. Slicing the precomputed table
            # would keep the full-block stride and mis-place every chunk
            # after the first one.
            offsets = (
                torch.arange(split, device=query.device)[:, None] * per_split_kv_len
                + torch.arange(per_split_kv_len, device=query.device)[None, :]
            ).view(1, 1, split, 1, per_split_kv_len)

        if pad > 0:
            # Slots at or beyond the real block length are padding.
            attn_weights_block = torch.where(offsets >= kv_len_block, masked_tensor, attn_weights_block)

        causal_mask = offsets > (q_positions - start_index)[:, None, None, :, None]
        attn_weights_block = torch.where(causal_mask, masked_tensor, attn_weights_block)

        block_max = attn_weights_block.max(dim=-1).values
        block_exp = torch.exp(attn_weights_block - block_max.unsqueeze(-1))
        if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()):
            block_max = torch.where(
                skip_future, torch.full_like(block_max, float(MIN_MASKED_ATTENTION_VALUE)), block_max
            )
            block_exp = torch.where(skip_future, torch.zeros_like(block_exp), block_exp)

        block_sum = block_exp.sum(dim=-1)
        block_output = torch.matmul(block_exp, value_5d)
        if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()):
            block_sum = torch.where(skip_future, torch.zeros_like(block_sum), block_sum)
            block_output = torch.where(skip_future, torch.zeros_like(block_output), block_output)

        max_blocks.append(block_max)
        sum_blocks.append(block_sum)
        output_blocks.append(block_output)

    max_stacked = torch.stack(max_blocks)
    sum_stacked = torch.stack(sum_blocks)
    output_stacked = torch.stack(output_blocks)

    # Merge partial softmax statistics across KV blocks ...
    max_across_blocks = max_stacked.max(dim=0).values
    weights_across_blocks = torch.exp(max_stacked - max_across_blocks.unsqueeze(0))
    sum_across_blocks = (weights_across_blocks * sum_stacked).sum(dim=0)
    output_across_blocks = (weights_across_blocks.unsqueeze(-1) * output_stacked).sum(dim=0)

    # ... then across the parallel splits inside a block.
    max_across_splits = max_across_blocks.max(dim=2).values
    weights_across_splits = torch.exp(max_across_blocks - max_across_splits.unsqueeze(2))
    sum_across_splits = (weights_across_splits * sum_across_blocks).sum(dim=2)
    output_across_splits = (weights_across_splits.unsqueeze(-1) * output_across_blocks).sum(dim=2)

    output = output_across_splits / sum_across_splits.unsqueeze(-1)
    # Undo the head folding applied to the query.
    output = output.view(batch_size, num_key_value_heads, num_heads_per_kv, q_len, kv_lora_rank).reshape(
        batch_size, num_heads, q_len, kv_lora_rank
    )
    attn_output = torch.matmul(output, per_head_v_up)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


def blocked_h_mla_attention_forward(
module: nn.Module,
q_a_proj_out: torch.Tensor,
Expand Down
6 changes: 5 additions & 1 deletion QEfficient/blocking/blocking_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def build_transformer_blocking_config(
)

if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []):
if "kv" in blocking_mode:
if "kv" in blocking_mode or "par" in blocking_mode:
attention_cfg["num_kv_blocks"] = get_num_kv_blocks_for_mla(seq_len, num_heads, ctx_len)

resolved_mode = _normalize_attention_mode(blocking_mode or "hqkv")
Expand Down Expand Up @@ -343,6 +343,10 @@ def build_transformer_blocking_config_for_transform(
if qaic_config.get("num_batch_blocks", False) and enable_blocking and "b" in blocking_mode:
mode_from_config = "b" + mode_from_config
blocking_config.num_batch_blocks = _get_valid_num_blocks(qaic_config, "num_batch_blocks")
if qaic_config.get("par_num_split", False) and qaic_config.get("num_kv_blocks", False) and enable_blocking and "par" in blocking_mode:
mode_from_config = "par" + mode_from_config
blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks")
blocking_config.par_num_split = _get_valid_num_blocks(qaic_config, "par_num_split")

# check if qaic config did not provide any blocking details
if mode_from_config == "":
Expand Down
6 changes: 6 additions & 0 deletions QEfficient/customop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFunc3DGeneralized,
CtxGatherFuncBlockedKV,
CtxScatterFunc,
CtxScatterFunc3D,
CtxScatterFunc3DGeneralized,
CtxScatterFunc3DInt,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherFuncBlockedKVCB,
Expand All @@ -26,7 +29,10 @@
"CtxGatherFuncBlockedKV",
"CtxScatterFunc",
"CtxGatherFunc3D",
"CtxGatherFunc3DGeneralized",
"CtxScatterFunc3D",
"CtxScatterFunc3DGeneralized",
"CtxScatterFunc3DInt",
"CustomRMSNormAIC",
"GemmaCustomRMSNormAIC",
"CtxGatherFuncCB",
Expand Down
Loading