Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
blocked_h_mla_attention_forward,
blocked_hqkv_attention_forward,
blocked_kv_attention_forward,
blocked_kv_attention_forward_headpar_offline,
blocked_kv_mla_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
Expand All @@ -44,6 +45,7 @@ class AttentionBlockingConfig:
head_block_size: Optional[int] = None
skip_kv: Optional[bool] = False
num_batch_blocks: Optional[int] = None
kv_blocking_headpar_split: Optional[int] = None


def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
Expand All @@ -59,6 +61,12 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
BlockingMode.BHQKV: blocked_bhqkv_attention_forward,
}

# Head-parallel dispatch table: a copy of _STRATEGIES in which only the KV
# blocking mode is rerouted to the head-parallel implementation; every other
# mode keeps its regular strategy.
_STRATEGIES_HEADPAR: Dict[BlockingMode, Callable] = dict(_STRATEGIES)
_STRATEGIES_HEADPAR[BlockingMode.KV] = blocked_kv_attention_forward_headpar_offline

_STRATEGIES_MLA: Dict[BlockingMode, Callable] = {
BlockingMode.KV: blocked_kv_mla_attention_forward,
BlockingMode.H: blocked_h_mla_attention_forward,
Expand Down Expand Up @@ -146,7 +154,10 @@ def generic_blocked_attention_interface(
sliding_window=sliding_window,
)

strategy = _STRATEGIES.get(blocking_config.mode)
if blocking_config.kv_blocking_headpar_split is not None:
strategy = _STRATEGIES_HEADPAR.get(blocking_config.mode)
else:
strategy = _STRATEGIES.get(blocking_config.mode)
attn_output, attn_weights = strategy(
module=module,
query=query,
Expand All @@ -161,6 +172,7 @@ def generic_blocked_attention_interface(
num_q_blocks=blocking_config.num_q_blocks,
head_block_size=blocking_config.head_block_size,
num_batch_blocks=blocking_config.num_batch_blocks,
configured_split=blocking_config.kv_blocking_headpar_split,
score_mod=score_mod,
position_bias=position_bias,
sinks=sinks,
Expand Down
146 changes: 146 additions & 0 deletions QEfficient/blocking/blocked_attention_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def _normalize_int(value: Optional[torch.Tensor | int]) -> int:
return int(value) if value is not None else 0


def _get_headpar_split(configured_split: int, num_kv_groups: int) -> int:
# configured split = 0 used as the case to default to num_kv_groups, so whenever kv_blocking_headpar_split is passed, we know to do headpar
if configured_split == 0:
configured_split = None
return max(1, int(configured_split if configured_split is not None else num_kv_groups))


def update_running_softmax(
current_max: torch.Tensor,
attn_weights_block: torch.Tensor,
Expand Down Expand Up @@ -202,6 +209,145 @@ def blocked_kv_attention_forward(
return attn_output, attn_weights


def blocked_kv_attention_forward_headpar_offline(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    past_key_value: Cache,
    *,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    skip_kv: bool = False,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    configured_split: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """KV-blocked attention where every KV block is further split for head-parallel compute.

    The KV range ``[0, past_seen_tokens)`` is walked in ``num_kv_blocks``
    blocks read from ``past_key_value``. Each block is reshaped into ``split``
    chunks so that scores for all chunks come from one 5D matmul. Per-chunk
    softmax statistics (max, sum of exponentials, unnormalised output) are
    collected for every block and merged in two stages: first across KV
    blocks, then across the ``split`` chunks. Optional ``sinks`` logits are
    folded into the final normaliser.

    Args:
        module: Attention module; only ``num_key_value_groups`` is read.
        query: 4D tensor unpacked as (batch, num_heads, seq_len, head_dim).
        key: Not referenced in this body (K is read from ``past_key_value``).
        value: Not referenced in this body (V is read from ``past_key_value``).
        attention_mask: Not referenced in this body; masking is derived from
            ``position_ids`` instead.
        scaling: Multiplier applied to the raw QK^T scores.
        num_kv_blocks: Number of KV blocks; values below 1 are raised to 1.
        cache_kwargs: Must provide ``"past_seen_tokens"`` and
            ``"position_ids"``; it is also forwarded to the cache read calls.
        layer_idx: Layer index forwarded to the cache read calls.
        past_key_value: Cache exposing ``read_only_blocked_K`` and
            ``read_only_blocked_V``.
        use_causal_mask: Not referenced in this body (the position-based
            causal mask is always applied).
        sliding_window: Not referenced in this body.
        skip_kv: When true, blocks that start after every query position are
            skipped (eager) or zeroed out (ONNX export / JIT tracing).
        position_bias: Not referenced in this body.
        sinks: Optional per-head sink logits added to the softmax
            denominator only.
        configured_split: Requested number of chunks per block; resolved by
            ``_get_headpar_split``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` has layout
        (batch, seq_len, num_heads, head_dim). Attention weights are not
        returned.
    """
    # Head-parallel block softmax: K is split into `split` chunks along the
    # ctx dimension, computed in parallel as a 5D matmul, then two-stage
    # merged (across kv-blocks, then across splits).
    batch_size, num_heads, seq_len, head_dim = query.shape
    # NOTE(review): falls back to None when the attribute is missing, in which
    # case the `//` below raises TypeError — confirm every caller's module
    # defines num_key_value_groups.
    num_kv_groups = getattr(module, "num_key_value_groups", None)
    past_seen_tokens = cache_kwargs.get("past_seen_tokens")
    position_ids = cache_kwargs.get("position_ids")
    num_kv_heads = num_heads // num_kv_groups
    split = _get_headpar_split(configured_split, num_kv_groups)
    num_kv_blocks = max(1, num_kv_blocks)
    # Ceiling division: every block except possibly the last has this length.
    # NOTE(review): when num_kv_blocks is large relative to past_seen_tokens,
    # (num_kv_blocks - 1) * kv_block_size can exceed past_seen_tokens and the
    # last block length computed in the loop becomes non-positive — confirm
    # callers only pass compatible values.
    kv_block_size = -(-past_seen_tokens // num_kv_blocks)
    # Largest position along the last dim of position_ids; used only for skip_kv.
    current_position = position_ids.max(dim=-1).values

    # Fold the query heads of one KV group into the sequence axis:
    # (B, H, S, D) -> (B, Hkv, S * groups, D), then broadcast over the chunks:
    # (B, Hkv, split, S * groups, D).
    query_folded = query.reshape(batch_size, num_kv_heads, seq_len * num_kv_groups, head_dim)
    query_5d = query_folded.unsqueeze(2).expand(batch_size, num_kv_heads, split, seq_len * num_kv_groups, head_dim)

    # Per-block softmax statistics, merged after the loop.
    max_blocks = []
    sum_blocks = []
    out_blocks = []

    for j in range(num_kv_blocks):
        start_index = j * kv_block_size
        # The last block takes whatever remains of the KV range.
        if j == num_kv_blocks - 1:
            kv_len_block = past_seen_tokens - start_index
        else:
            kv_len_block = kv_block_size
        end_index = start_index + kv_len_block

        skip_future = None
        if skip_kv:
            # True when this block starts after every query position.
            skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()
            # Eager mode only: data-dependent early exit. During export/tracing
            # the block is still computed and neutralised via torch.where below.
            # NOTE(review): if this breaks on the first iteration the lists
            # stay empty and torch.stack below would raise.
            if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                if skip_future.item():
                    break

        k_block = past_key_value.read_only_blocked_K(start_index, end_index, layer_idx, cache_kwargs)
        block_len = kv_len_block
        pad_len = 0
        # Pad the block along the second-to-last dim so its length divides
        # evenly into `split` chunks.
        if block_len % split != 0:
            pad_len = split - (block_len % split)
            k_block = nn.functional.pad(k_block, (0, 0, 0, pad_len))
            block_len += pad_len
        split_block_len = block_len // split

        # presumably k_block is (B, Hkv, block_len, D) — the view below requires
        # exactly that many elements; verify against the cache implementation.
        key_5d = k_block.view(batch_size, num_kv_heads, split, split_block_len, head_dim)
        # Scores: (B, Hkv, split, S * groups, split_block_len).
        attn_weights_block = torch.matmul(query_5d, key_5d.transpose(-1, -2)) * scaling

        if pad_len > 0:
            # Mask the padded key slots: within each chunk, positions at or
            # beyond the number of real keys that fall in that chunk.
            chunk_start = torch.arange(split, device=query.device) * split_block_len
            valid_in_chunk = kv_len_block - chunk_start
            key_idx = torch.arange(split_block_len, device=query.device)
            pad_mask = key_idx.unsqueeze(0) >= valid_in_chunk.unsqueeze(1)
            attn_weights_block = attn_weights_block.masked_fill(pad_mask.view(1, 1, split, 1, split_block_len), -3.0e4)

        # Absolute key position of every (chunk, offset) slot in this block.
        key_abs = (
            start_index
            + torch.arange(split, device=query.device)[:, None] * split_block_len
            + torch.arange(split_block_len, device=query.device)[None, :]
        )
        # Tile positions once per group so they line up with the folded query
        # axis (group-major, then sequence).
        query_pos = position_ids.repeat(1, num_kv_groups)
        # Causal mask: a key is hidden when its position is after the query's.
        causal_mask = key_abs[None, :, None, :] > query_pos[:, None, :, None]
        attn_weights_block = attn_weights_block.masked_fill(causal_mask.unsqueeze(1), -3.0e4)

        # Per-chunk softmax statistics, each relative to the chunk's own max.
        max_block = attn_weights_block.max(dim=-1).values
        exp_block = torch.exp(attn_weights_block - max_block.unsqueeze(-1))
        if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()):
            # Traced path: neutralise a skipped block instead of breaking out.
            max_block = torch.where(skip_future, torch.full_like(max_block, MIN_MASKED_ATTENTION_VALUE), max_block)
            exp_block = torch.where(skip_future, torch.zeros_like(exp_block), exp_block)

        v_block = past_key_value.read_only_blocked_V(start_index, end_index, layer_idx, cache_kwargs)
        if pad_len > 0:
            # Same padding as K so the chunk layout matches.
            v_block = nn.functional.pad(v_block, (0, 0, 0, pad_len))
        value_5d = v_block.view(batch_size, num_kv_heads, split, split_block_len, head_dim)
        sum_block = exp_block.sum(dim=-1)
        out_block = torch.matmul(exp_block, value_5d)
        if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()):
            sum_block = torch.where(skip_future, torch.zeros_like(sum_block), sum_block)
            out_block = torch.where(skip_future, torch.zeros_like(out_block), out_block)

        max_blocks.append(max_block)
        sum_blocks.append(sum_block)
        out_blocks.append(out_block)

    # Stage 1: merge across KV blocks (dim 0 of the stacked tensors) by
    # rescaling every block's statistics to the shared maximum.
    max_stacked = torch.stack(max_blocks)
    sum_stacked = torch.stack(sum_blocks)
    out_stacked = torch.stack(out_blocks)
    block_max = max_stacked.max(dim=0).values
    block_weight = torch.exp(max_stacked - block_max.unsqueeze(0))
    block_sum = (block_weight * sum_stacked).sum(dim=0)
    block_out = (block_weight.unsqueeze(-1) * out_stacked).sum(dim=0)

    # Stage 2: merge across the `split` chunks (dim 2) the same way.
    split_max = block_max.max(dim=2).values
    split_weight = torch.exp(block_max - split_max.unsqueeze(2))
    split_sum = (split_weight * block_sum).sum(dim=2)
    split_out = (split_weight.unsqueeze(-1) * block_out).sum(dim=2)

    if sinks is not None:
        # One sink logit per query head, broadcast over batch and sequence.
        sinks_logits = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1)

        # Fold heads the same way as query: [B, H, QL, 1] -> [B, Hkv, QL*num_kv_groups, 1]
        sinks_folded = sinks_logits.reshape(batch_size, num_kv_heads, seq_len * num_kv_groups, 1)
        sink_logits = sinks_folded.squeeze(-1)  # [B, Hkv, QL*num_kv_groups]

        # Treat the sink as one extra logit: it enlarges the denominator but
        # adds nothing to the weighted value sum.
        new_max = torch.maximum(split_max, sink_logits)
        scale_old = torch.exp(split_max - new_max)
        scale_sink = torch.exp(sink_logits - new_max)

        split_sum = split_sum * scale_old + scale_sink
        split_out = split_out * scale_old.unsqueeze(-1)
        split_max = new_max

    # Normalise, then unfold the group axis back into the head axis.
    output = split_out / split_sum.unsqueeze(-1)
    attn_output = output.view(batch_size, num_kv_heads, num_kv_groups, seq_len, head_dim).reshape(
        batch_size, num_heads, seq_len, head_dim
    )
    # (B, H, S, D) -> (B, S, H, D); attention weights are not materialised.
    return attn_output.transpose(1, 2).contiguous(), None


def blocked_qkv_attention_forward(
module: nn.Module,
query: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions QEfficient/blocking/blocking_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,7 @@ def build_transformer_blocking_config_for_transform(
if qaic_config.get("skip_kv", False) and enable_blocking:
blocking_config.skip_kv = qaic_config.get("skip_kv")

if qaic_config.get("kv_blocking_headpar_split", None) is not None and enable_blocking:
blocking_config.kv_blocking_headpar_split = qaic_config.get("kv_blocking_headpar_split")

return blocking_config
Loading
Loading