From d906314e0c328759c42757a3e564a848456d02b7 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sun, 21 Jun 2026 14:23:30 +0200 Subject: [PATCH] Tighten TQ4 SDPA input validation --- backends/cuda/tests/test_tq4_sdpa.py | 62 ++++++++ backends/cuda/triton/kernels/tq4_sdpa.py | 175 ++++++++++++++++++----- 2 files changed, 202 insertions(+), 35 deletions(-) diff --git a/backends/cuda/tests/test_tq4_sdpa.py b/backends/cuda/tests/test_tq4_sdpa.py index f4cc1d770ef..4c6008fb8fd 100644 --- a/backends/cuda/tests/test_tq4_sdpa.py +++ b/backends/cuda/tests/test_tq4_sdpa.py @@ -192,6 +192,21 @@ def _run_test( f"(B={B} H_q={H_q} H_kv={H_kv} Lq={Lq} Lk={Lk} D={D})", ) + def _make_valid_tq4_args(self, B=1, H_q=4, H_kv=4, Lq=1, Lk=64, D=64): + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + k_packed, k_norms = _compress(k, boundaries, rotation) + v_packed, v_norms = _compress(v, boundaries, rotation) + return q, k_packed, k_norms, v_packed, v_norms, centroids, rotation + # ------------------------------------------------------------------ # MHA (H_q == H_kv) # ------------------------------------------------------------------ @@ -678,6 +693,53 @@ def test_kv_len_clamp_128k(self): # Validation errors # ------------------------------------------------------------------ + def test_3d_query_rejected(self): + """3D query raises RuntimeError before shape unpacking.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + q_3d = q.squeeze(2) + with self.assertRaisesRegex(RuntimeError, "query must be 4D"): + self.tq4_sdpa(q_3d, k_p, k_n, v_p, v_n, centroids, rotation) + + def test_3d_mask_rejected(self): + """3D attention mask raises RuntimeError before shape indexing.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + mask_3d = torch.ones(1, 1, 64, dtype=torch.bool, device="cuda") + with self.assertRaisesRegex(RuntimeError, "attn_mask must be 4D"): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, centroids, rotation, mask_3d) + + def test_wrong_rotation_shape_rejected(self): + """Rotation shape must match query head_dim.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + bad_rotation = rotation[:-1, :] + with self.assertRaisesRegex(RuntimeError, "rotation must have shape"): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, centroids, bad_rotation) + + def test_wrong_centroids_shape_rejected(self): + """Centroids shape must be length 16.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + bad_centroids = centroids[:-1] + with self.assertRaisesRegex(RuntimeError, "centroids must have shape"): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, bad_centroids, rotation) + + def test_cpu_k_norms_with_cuda_query_rejected(self): + """k_norms must be on the same CUDA device as query.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + with self.assertRaisesRegex(RuntimeError, "same CUDA device as query"): + self.tq4_sdpa(q, k_p, k_n.cpu(), v_p, v_n, centroids, rotation) + + def test_v_packed_shape_mismatch_rejected(self): + """v_packed must match the packed K layout.""" + args = self._make_valid_tq4_args() + q, k_p, k_n, v_p, v_n, centroids, rotation = args + bad_v_p = v_p[:, :, :-1, :] + with self.assertRaisesRegex(RuntimeError, "v_packed shape mismatch"): + self.tq4_sdpa(q, k_p, k_n, bad_v_p, v_n, centroids, rotation) + def test_hq_not_divisible_by_hkv_rejected(self): """H_Q not divisible by H_KV raises RuntimeError.""" D = 64 diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index e1576f7e446..7f8c4c4e717 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -614,25 +614,110 @@ def grid(meta): # --------------------------------------------------------------------------- -def _validate_tq4_inputs(query, k_packed, v_packed): +def _check_rank(tensor, name, expected_rank, shape_desc): + if tensor.dim() != expected_rank: + raise RuntimeError( + f"{name} must be {expected_rank}D {shape_desc}, got {tensor.dim()}D" + ) + + +def _check_same_cuda_device(tensor, name, expected_device): + if not tensor.is_cuda or tensor.device != expected_device: + raise RuntimeError( + f"{name} must be on the same CUDA device as query " + f"({expected_device}), got {tensor.device}" + ) + + +def _validate_tq4_inputs( + query, + k_packed, + k_norms, + v_packed, + v_norms, + centroids, + rotation, + attn_mask, + kv_len, +): """Validate tensor shapes, dtypes, and device for tq4_sdpa.""" - B, H_Q, N_Q, D = query.shape - B_kp, H_KV, N_KV, HALF_D = k_packed.shape + _check_rank(query, "query", 4, "[B, H, L, D]") + _check_rank(k_packed, "k_packed", 4, "[B, H, L, D//2]") + _check_rank(k_norms, "k_norms", 4, "[B, H, L, 1]") + _check_rank(v_packed, "v_packed", 4, "[B, H, L, D//2]") + _check_rank(v_norms, "v_norms", 4, "[B, H, L, 1]") + _check_rank(centroids, "centroids", 1, "[16]") + _check_rank(rotation, "rotation", 2, "[D, D]") + if attn_mask is not None: + _check_rank(attn_mask, "attn_mask", 4, "[B, 1, L_Q, L_KV]") + if kv_len is not None and kv_len.dim() > 1: + raise RuntimeError( + f"kv_len must be a scalar or 1D tensor with one element, " + f"got {kv_len.dim()}D" + ) if not query.is_cuda: raise RuntimeError("query must be a CUDA tensor") + expected_device = query.device + for name, tensor in ( + ("k_packed", k_packed), + ("k_norms", k_norms), + ("v_packed", v_packed), + ("v_norms", v_norms), + ("centroids", centroids), + ("rotation", rotation), + ): + _check_same_cuda_device(tensor, name, expected_device) + if attn_mask is not None: + _check_same_cuda_device(attn_mask, "attn_mask", expected_device) + if kv_len is not None: + _check_same_cuda_device(kv_len, "kv_len", expected_device) + if query.dtype != torch.bfloat16: raise RuntimeError(f"query must be bfloat16, got {query.dtype}") - if query.dim() != 4: - raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D") - if k_packed.dim() != 4 or v_packed.dim() != 4: - raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]") - if k_packed.dtype != torch.uint8 or v_packed.dtype != torch.uint8: - raise RuntimeError("k_packed and v_packed must be uint8") + if k_packed.dtype != torch.uint8: + raise RuntimeError(f"k_packed must be uint8, got {k_packed.dtype}") + if v_packed.dtype != torch.uint8: + raise RuntimeError(f"v_packed must be uint8, got {v_packed.dtype}") + if k_norms.dtype not in (torch.float32, torch.bfloat16): + raise RuntimeError( + f"k_norms must be float32 or bfloat16, got {k_norms.dtype}" + ) + if v_norms.dtype not in (torch.float32, torch.bfloat16): + raise RuntimeError( + f"v_norms must be float32 or bfloat16, got {v_norms.dtype}" + ) + if not torch.is_floating_point(centroids): + raise RuntimeError(f"centroids must be floating point, got {centroids.dtype}") + if not torch.is_floating_point(rotation): + raise RuntimeError(f"rotation must be floating point, got {rotation.dtype}") + if attn_mask is not None and attn_mask.dtype != torch.bool: + raise RuntimeError( + f"attn_mask must be bool, got {attn_mask.dtype}. " + "Additive float masks are not supported." + ) + if kv_len is not None and kv_len.dtype not in ( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ): + raise RuntimeError(f"kv_len must have integer dtype, got {kv_len.dtype}") + + B, H_Q, N_Q, D = query.shape + B_kp, H_KV, N_KV, HALF_D = k_packed.shape + + if v_packed.shape != k_packed.shape: + raise RuntimeError( + f"v_packed shape mismatch: expected {tuple(k_packed.shape)} to match " + f"k_packed, got {tuple(v_packed.shape)}" + ) if B_kp != B: raise RuntimeError( f"Batch dim mismatch: query has B={B}, k_packed has B={B_kp}" ) + if H_KV == 0: + raise RuntimeError("k_packed head dimension must be greater than 0") if H_Q % H_KV != 0: raise RuntimeError( f"H_Q must be a multiple of H_KV for GQA head mapping, " @@ -647,34 +732,46 @@ def _validate_tq4_inputs(query, k_packed, v_packed): f"HEAD_DIM must be a power of 2, got {D}. " "Non-power-of-2 head dims are not supported." ) - - -def _validate_tq4_mask(attn_mask, B, N_Q, N_KV): - """Validate attention mask for tq4_sdpa.""" - if attn_mask is None: - return - if attn_mask.dtype != torch.bool: + expected_norm_shape = (B, H_KV, N_KV, 1) + if k_norms.shape != expected_norm_shape: raise RuntimeError( - f"attn_mask must be bool, got {attn_mask.dtype}. " - "Additive float masks are not supported." + f"k_norms shape mismatch: expected {expected_norm_shape} to match " + f"k_packed layout, got {tuple(k_norms.shape)}" ) - if not attn_mask.is_cuda: - raise RuntimeError("attn_mask must be a CUDA tensor") - if attn_mask.shape[1] != 1: + if v_norms.shape != expected_norm_shape: raise RuntimeError( - f"attn_mask head dimension must be 1 (broadcast over heads); " - f"per-head masks are not supported. " - f"Got attn_mask.shape={attn_mask.shape}" + f"v_norms shape mismatch: expected {expected_norm_shape} to match " + f"v_packed layout, got {tuple(v_norms.shape)}" ) - if ( - attn_mask.shape[0] != B - or attn_mask.shape[2] != N_Q - or attn_mask.shape[3] != N_KV - ): + if centroids.shape != (16,): raise RuntimeError( - f"attn_mask shape mismatch: expected " - f"[B={B}, 1, L_Q={N_Q}, L_KV={N_KV}], " - f"got {attn_mask.shape}" + f"centroids must have shape (16,), got {tuple(centroids.shape)}" + ) + if rotation.shape != (D, D): + raise RuntimeError( + f"rotation must have shape ({D}, {D}), got {tuple(rotation.shape)}" + ) + + if attn_mask is not None: + if attn_mask.shape[1] != 1: + raise RuntimeError( + f"attn_mask head dimension must be 1 (broadcast over heads); " + f"per-head masks are not supported. " + f"Got attn_mask.shape={attn_mask.shape}" + ) + if ( + attn_mask.shape[0] != B + or attn_mask.shape[2] != N_Q + or attn_mask.shape[3] != N_KV + ): + raise RuntimeError( + f"attn_mask shape mismatch: expected " + f"[B={B}, 1, L_Q={N_Q}, L_KV={N_KV}], " + f"got {attn_mask.shape}" + ) + if kv_len is not None and kv_len.numel() != 1: + raise RuntimeError( + f"kv_len must contain exactly one element, got {kv_len.numel()}" ) @@ -736,13 +833,21 @@ def tq4_sdpa( Returns: [B, H_Q, L_Q, D] bf16 attention output """ - _validate_tq4_inputs(query, k_packed, v_packed) + _validate_tq4_inputs( + query, + k_packed, + k_norms, + v_packed, + v_norms, + centroids, + rotation, + attn_mask, + kv_len, + ) B, H_Q, N_Q, D = query.shape _, H_KV, N_KV, HALF_D = k_packed.shape - _validate_tq4_mask(attn_mask, B, N_Q, N_KV) - sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale) num_groups = H_Q // H_KV