From 5f3540c13cc1f8c10eb19523f32be2406de306ec Mon Sep 17 00:00:00 2001 From: Kushal Dulla Date: Thu, 14 May 2026 15:19:18 +0530 Subject: [PATCH 1/2] added head parallel kv blocking Signed-off-by: Kushal Dulla --- QEfficient/blocking/attention_blocking.py | 14 +- .../blocking/blocked_attention_forwards.py | 148 +++++++++++++- QEfficient/blocking/blocking_configurator.py | 3 + QEfficient/transformers/cache_utils.py | 186 ++++++++++++++++++ .../test_causal_lm_blocking_hqkv.py | 42 +++- 5 files changed, 385 insertions(+), 8 deletions(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index 2ab5c03bec..9665a75f3a 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -20,6 +20,7 @@ blocked_h_mla_attention_forward, blocked_hqkv_attention_forward, blocked_kv_attention_forward, + blocked_kv_attention_forward_headpar_offline, blocked_kv_mla_attention_forward, blocked_q_attention_forward, blocked_qkv_attention_forward, @@ -44,6 +45,7 @@ class AttentionBlockingConfig: head_block_size: Optional[int] = None skip_kv: Optional[bool] = False num_batch_blocks: Optional[int] = None + kv_blocking_headpar_split: Optional[int] = None def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: @@ -59,6 +61,12 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: BlockingMode.BHQKV: blocked_bhqkv_attention_forward, } +# replace just the KV blocking strategy with headpar version +_STRATEGIES_HEADPAR: Dict[BlockingMode, Callable] = { + **_STRATEGIES, + BlockingMode.KV: blocked_kv_attention_forward_headpar_offline, +} + _STRATEGIES_MLA: Dict[BlockingMode, Callable] = { BlockingMode.KV: blocked_kv_mla_attention_forward, BlockingMode.H: blocked_h_mla_attention_forward, @@ -146,7 +154,10 @@ def generic_blocked_attention_interface( sliding_window=sliding_window, ) - strategy = _STRATEGIES.get(blocking_config.mode) + if blocking_config.kv_blocking_headpar_split is not None: + strategy = _STRATEGIES_HEADPAR.get(blocking_config.mode) + else: + strategy = _STRATEGIES.get(blocking_config.mode) attn_output, attn_weights = strategy( module=module, query=query, @@ -161,6 +172,7 @@ def generic_blocked_attention_interface( num_q_blocks=blocking_config.num_q_blocks, head_block_size=blocking_config.head_block_size, num_batch_blocks=blocking_config.num_batch_blocks, + configured_split=blocking_config.kv_blocking_headpar_split, score_mod=score_mod, position_bias=position_bias, sinks=sinks, diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 6aed6e49f9..5739a350db 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -42,6 +42,13 @@ def _normalize_int(value: Optional[torch.Tensor | int]) -> int: return int(value) if value is not None else 0 +def _get_headpar_split(configured_split: int, num_kv_groups: int) -> int: + # configured split = 0 used as the case to default to num_kv_groups, so whenever kv_blocking_headpar_split is passed, we know to do headpar + if configured_split == 0: + configured_split = None + return max(1, int(configured_split if configured_split is not None else num_kv_groups)) + + def update_running_softmax( current_max: torch.Tensor, attn_weights_block: torch.Tensor, @@ -183,7 +190,7 @@ def blocked_kv_attention_forward( if mask_block is None: mask_block = causal_mask_block else: - mask_block = mask_block.to(torch.bool) | causal_mask_block + mask_block = mask_block.to(torch.bool) | causal_mask_blockgiot if mask_block is not None: attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block) @@ -202,6 +209,145 @@ def blocked_kv_attention_forward( return attn_output, attn_weights +def blocked_kv_attention_forward_headpar_offline( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + cache_kwargs: Dict[str, Any], + layer_idx: int, + past_key_value: Cache, + *, + use_causal_mask: bool = False, + sliding_window: Optional[int] = None, + skip_kv: bool = False, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + configured_split: Optional[int] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Head-parallel block softmax: K is split into `split` chunks along the + # ctx dimension, computed in parallel as a 5D matmul, then two-stage + # merged (across kv-blocks, then across splits). + batch_size, num_heads, seq_len, head_dim = query.shape + num_kv_groups = getattr(module, "num_key_value_groups", None) + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + position_ids = cache_kwargs.get("position_ids") + num_kv_heads = num_heads // num_kv_groups + split = _get_headpar_split(configured_split, num_kv_groups) + num_kv_blocks = max(1, num_kv_blocks) + kv_block_size = -(-past_seen_tokens // num_kv_blocks) + current_position = position_ids.max(dim=-1).values + + query_folded = query.reshape(batch_size, num_kv_heads, seq_len * num_kv_groups, head_dim) + query_5d = query_folded.unsqueeze(2).expand(batch_size, num_kv_heads, split, seq_len * num_kv_groups, head_dim) + + max_blocks = [] + sum_blocks = [] + out_blocks = [] + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = past_seen_tokens - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + k_block = past_key_value.read_only_blocked_K(start_index, end_index, layer_idx, cache_kwargs) + block_len = kv_len_block + pad_len = 0 + if block_len % split != 0: + pad_len = split - (block_len % split) + k_block = nn.functional.pad(k_block, (0, 0, 0, pad_len)) + block_len += pad_len + split_block_len = block_len // split + + key_5d = k_block.view(batch_size, num_kv_heads, split, split_block_len, head_dim) + attn_weights_block = torch.matmul(query_5d, key_5d.transpose(-1, -2)) * scaling + + if pad_len > 0: + chunk_start = torch.arange(split, device=query.device) * split_block_len + valid_in_chunk = kv_len_block - chunk_start + key_idx = torch.arange(split_block_len, device=query.device) + pad_mask = key_idx.unsqueeze(0) >= valid_in_chunk.unsqueeze(1) + attn_weights_block = attn_weights_block.masked_fill(pad_mask.view(1, 1, split, 1, split_block_len), -3.0e4) + + key_abs = ( + start_index + + torch.arange(split, device=query.device)[:, None] * split_block_len + + torch.arange(split_block_len, device=query.device)[None, :] + ) + query_pos = position_ids.repeat(1, num_kv_groups) + causal_mask = key_abs[None, :, None, :] > query_pos[:, None, :, None] + attn_weights_block = attn_weights_block.masked_fill(causal_mask.unsqueeze(1), -3.0e4) + + max_block = attn_weights_block.max(dim=-1).values + exp_block = torch.exp(attn_weights_block - max_block.unsqueeze(-1)) + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + max_block = torch.where(skip_future, torch.full_like(max_block, MIN_MASKED_ATTENTION_VALUE), max_block) + exp_block = torch.where(skip_future, torch.zeros_like(exp_block), exp_block) + + v_block = past_key_value.read_only_blocked_V(start_index, end_index, layer_idx, cache_kwargs) + if pad_len > 0: + v_block = nn.functional.pad(v_block, (0, 0, 0, pad_len)) + value_5d = v_block.view(batch_size, num_kv_heads, split, split_block_len, head_dim) + sum_block = exp_block.sum(dim=-1) + out_block = torch.matmul(exp_block, value_5d) + if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + sum_block = torch.where(skip_future, torch.zeros_like(sum_block), sum_block) + out_block = torch.where(skip_future, torch.zeros_like(out_block), out_block) + + max_blocks.append(max_block) + sum_blocks.append(sum_block) + out_blocks.append(out_block) + + max_stacked = torch.stack(max_blocks) + sum_stacked = torch.stack(sum_blocks) + out_stacked = torch.stack(out_blocks) + block_max = max_stacked.max(dim=0).values + block_weight = torch.exp(max_stacked - block_max.unsqueeze(0)) + block_sum = (block_weight * sum_stacked).sum(dim=0) + block_out = (block_weight.unsqueeze(-1) * out_stacked).sum(dim=0) + + split_max = block_max.max(dim=2).values + split_weight = torch.exp(block_max - split_max.unsqueeze(2)) + split_sum = (split_weight * block_sum).sum(dim=2) + split_out = (split_weight.unsqueeze(-1) * block_out).sum(dim=2) + + if sinks is not None: + sinks_logits = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + + # Fold heads the same way as query: [B, H, QL, 1] -> [B, Hkv, QL*num_kv_groups, 1] + sinks_folded = sinks_logits.reshape(batch_size, num_kv_heads, seq_len * num_kv_groups, 1) + sink_logits = sinks_folded.squeeze(-1) # [B, Hkv, QL*num_kv_groups] + + new_max = torch.maximum(split_max, sink_logits) + scale_old = torch.exp(split_max - new_max) + scale_sink = torch.exp(sink_logits - new_max) + + split_sum = split_sum * scale_old + scale_sink + split_out = split_out * scale_old.unsqueeze(-1) + split_max = new_max + + output = split_out / split_sum.unsqueeze(-1) + attn_output = output.view(batch_size, num_kv_heads, num_kv_groups, seq_len, head_dim).reshape( + batch_size, num_heads, seq_len, head_dim + ) + return attn_output.transpose(1, 2).contiguous(), None + + def blocked_qkv_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index deed73a7bf..a2699694e3 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -360,4 +360,7 @@ def build_transformer_blocking_config_for_transform( if qaic_config.get("skip_kv", False) and enable_blocking: blocking_config.skip_kv = qaic_config.get("skip_kv") + if qaic_config.get("kv_blocking_headpar_split", None) is not None and enable_blocking: + blocking_config.kv_blocking_headpar_split = qaic_config.get("kv_blocking_headpar_split") + return blocking_config diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 799717bf83..f540bb5ffc 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -132,6 +132,91 @@ def read_only(self, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + def read_only_blocked_K(self, start_index, end_index, cache_kwargs): + """ + Reads the `key_states` for each KV block. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + start_index (`int`): + Start index of the K/V block to read + + end_index (`int`): + End index of the K/V block to read + + Return: + the updated key states. + """ + # Gather + k_out = self.keys + if k_out is not None: + self._mark_initialized(k_out) + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = k_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices) + else: + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices) + + return k_out + + def read_only_blocked_V(self, start_index, end_index, cache_kwargs): + """ + Reads the `value_states` for the layer for each KV block. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + start_index (`int`): + Start index of the K/V block to read + + end_index (`int`): + End index of the K/V block to read + + Return: + the updated value states. + """ + # Gather + v_out = self.values + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = v_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices) + else: + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return v_out + def read_only_blockedKV(self, start_index, end_index, cache_kwargs): """ Reads the `key_states` and `value_states` for the layer for each KV block. @@ -581,6 +666,44 @@ def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): """ return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs) + def read_only_blocked_K(self, start_index, end_index, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + start_index (`int`): + Start index of the K/V block to read + end_index (`int`): + End index of the K/V block to read + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only_blocked_K(start_index, end_index, cache_kwargs) + + def read_only_blocked_V(self, start_index, end_index, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + start_index (`int`): + Start index of the K/V block to read + end_index (`int`): + End index of the K/V block to read + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only_blocked_V(start_index, end_index, cache_kwargs) + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -1122,6 +1245,69 @@ def read_only_blockedKV( v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + def read_only_blocked_K( + self, + start_idx: torch.Tensor, + end_idx: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs + + k_out = self.key_cache[layer_idx] + + batch, num_kv_heads, _, _ = k_out.shape + + ctx_indices = torch.arange(start=start_idx, end=end_idx)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices) + else: + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices) + + return k_out + + def read_only_blocked_V( + self, + start_idx: torch.Tensor, + end_idx: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs + + v_out = self.value_cache[layer_idx] + + batch, num_kv_heads, _, _ = v_out.shape + + ctx_indices = torch.arange(start=start_idx, end=end_idx)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices) + else: + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return v_out + def update( self, key_states: torch.Tensor, diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 4bf067e7c4..8530ee6c33 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -31,7 +31,6 @@ @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -48,6 +47,12 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 ) + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -77,7 +82,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -94,6 +98,12 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup ) + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -123,7 +133,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -149,6 +158,12 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup ) + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -178,7 +193,6 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -203,6 +217,12 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m num_devices=4, ) + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -244,7 +264,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -269,6 +288,12 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma continuous_batching=True, ) + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -310,7 +335,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -345,6 +369,12 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup=manual_cleanup, continuous_batching=True, ) + + # kv blocking only, head parallel + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) From 76035e426e6c55a8b9c35ecc65cca86979c2aa3f Mon Sep 17 00:00:00 2001 From: Kushal Dulla Date: Thu, 14 May 2026 15:24:10 +0530 Subject: [PATCH 2/2] formatting issues Signed-off-by: Kushal Dulla --- QEfficient/blocking/blocked_attention_forwards.py | 2 +- .../models/causal_lm_models/test_causal_lm_blocking_hqkv.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 5739a350db..77d6c2dd7e 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -190,7 +190,7 @@ def blocked_kv_attention_forward( if mask_block is None: mask_block = causal_mask_block else: - mask_block = mask_block.to(torch.bool) | causal_mask_blockgiot + mask_block = mask_block.to(torch.bool) | causal_mask_block if mask_block is not None: attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 8530ee6c33..eb08e2c8f8 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -293,7 +293,7 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 ) - + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -369,7 +369,7 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup=manual_cleanup, continuous_batching=True, ) - + # kv blocking only, head parallel qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, kv_blocking_headpar_split=0) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(