Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions backends/cuda/tests/test_tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ def _run_test(
f"(B={B} H_q={H_q} H_kv={H_kv} Lq={Lq} Lk={Lk} D={D})",
)

def _make_valid_tq4_args(self, B=1, H_q=4, H_kv=4, Lq=1, Lk=64, D=64):
    """Build one well-formed set of CUDA arguments for ``tq4_sdpa``.

    The RNG is seeded with 42 so every caller gets the same tensors.
    Random bf16 query/key/value tensors are drawn on the CUDA device with
    shapes ``[B, H_q, Lq, D]`` (query) and ``[B, H_kv, Lk, D]`` (key and
    value); key and value are then passed through ``_compress``.

    Returns:
        Tuple ``(q, k_packed, k_norms, v_packed, v_norms, centroids,
        rotation)`` in the positional order the validation tests pass
        to ``self.tq4_sdpa``.
    """
    torch.manual_seed(42)
    # Move the codebook pieces to the GPU as they are unpacked.
    centroids, boundaries, rotation = (
        piece.cuda() for piece in _make_codebook_and_rotation(D)
    )

    def _bf16_randn(heads, length):
        return torch.randn(
            B, heads, length, D, dtype=torch.bfloat16, device="cuda"
        )

    # Draw order (query, key, value) is kept so the seeded values match.
    query = _bf16_randn(H_q, Lq)
    keys = _bf16_randn(H_kv, Lk)
    values = _bf16_randn(H_kv, Lk)

    k_packed, k_norms = _compress(keys, boundaries, rotation)
    v_packed, v_norms = _compress(values, boundaries, rotation)
    return query, k_packed, k_norms, v_packed, v_norms, centroids, rotation

# ------------------------------------------------------------------
# MHA (H_q == H_kv)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -678,6 +693,53 @@ def test_kv_len_clamp_128k(self):
# Validation errors
# ------------------------------------------------------------------

def test_3d_query_rejected(self):
    """A 3D query must fail with "query must be 4D".

    The otherwise valid query has axis 2 squeezed away before the call, so
    the rank check is the only thing that can reject it.
    """
    query, *remaining = self._make_valid_tq4_args()
    squeezed_query = query.squeeze(2)
    with self.assertRaisesRegex(RuntimeError, "query must be 4D"):
        self.tq4_sdpa(squeezed_query, *remaining)

def test_3d_mask_rejected(self):
    """A 3D boolean attention mask must fail with "attn_mask must be 4D".

    All other arguments are valid; the mask is passed as the eighth
    positional argument.
    """
    valid_args = self._make_valid_tq4_args()
    flat_mask = torch.ones(1, 1, 64, dtype=torch.bool, device="cuda")
    with self.assertRaisesRegex(RuntimeError, "attn_mask must be 4D"):
        self.tq4_sdpa(*valid_args, flat_mask)

def test_wrong_rotation_shape_rejected(self):
    """A rotation matrix with its last row dropped must be rejected.

    Expects ``RuntimeError`` matching "rotation must have shape".
    """
    *leading_args, rotation = self._make_valid_tq4_args()
    truncated_rotation = rotation[:-1, :]
    with self.assertRaisesRegex(RuntimeError, "rotation must have shape"):
        self.tq4_sdpa(*leading_args, truncated_rotation)

def test_wrong_centroids_shape_rejected(self):
    """A centroid table with its last entry dropped must be rejected.

    Expects ``RuntimeError`` matching "centroids must have shape".
    """
    *leading_args, centroids, rotation = self._make_valid_tq4_args()
    short_centroids = centroids[:-1]
    with self.assertRaisesRegex(RuntimeError, "centroids must have shape"):
        self.tq4_sdpa(*leading_args, short_centroids, rotation)

def test_cpu_k_norms_with_cuda_query_rejected(self):
    """``k_norms`` moved to the CPU must be rejected when the query is on CUDA.

    Expects ``RuntimeError`` matching "same CUDA device as query".
    """
    query, k_packed, k_norms, *trailing_args = self._make_valid_tq4_args()
    with self.assertRaisesRegex(RuntimeError, "same CUDA device as query"):
        self.tq4_sdpa(query, k_packed, k_norms.cpu(), *trailing_args)

def test_v_packed_shape_mismatch_rejected(self):
    """``v_packed`` with one fewer entry along axis 2 must be rejected.

    Expects ``RuntimeError`` matching "v_packed shape mismatch".
    """
    query, k_packed, k_norms, v_packed, *trailing_args = (
        self._make_valid_tq4_args()
    )
    shortened_v_packed = v_packed[:, :, :-1, :]
    with self.assertRaisesRegex(RuntimeError, "v_packed shape mismatch"):
        self.tq4_sdpa(
            query, k_packed, k_norms, shortened_v_packed, *trailing_args
        )

def test_hq_not_divisible_by_hkv_rejected(self):
"""H_Q not divisible by H_KV raises RuntimeError."""
D = 64
Expand Down
175 changes: 140 additions & 35 deletions backends/cuda/triton/kernels/tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,25 +614,110 @@ def grid(meta):
# ---------------------------------------------------------------------------


def _validate_tq4_inputs(query, k_packed, v_packed):
def _check_rank(tensor, name, expected_rank, shape_desc):
if tensor.dim() != expected_rank:
raise RuntimeError(
f"{name} must be {expected_rank}D {shape_desc}, got {tensor.dim()}D"
)


def _check_same_cuda_device(tensor, name, expected_device):
if not tensor.is_cuda or tensor.device != expected_device:
raise RuntimeError(
f"{name} must be on the same CUDA device as query "
f"({expected_device}), got {tensor.device}"
)


def _validate_tq4_inputs(
query,
k_packed,
k_norms,
v_packed,
v_norms,
centroids,
rotation,
attn_mask,
kv_len,
):
"""Validate tensor shapes, dtypes, and device for tq4_sdpa."""
B, H_Q, N_Q, D = query.shape
B_kp, H_KV, N_KV, HALF_D = k_packed.shape
_check_rank(query, "query", 4, "[B, H, L, D]")
_check_rank(k_packed, "k_packed", 4, "[B, H, L, D//2]")
_check_rank(k_norms, "k_norms", 4, "[B, H, L, 1]")
_check_rank(v_packed, "v_packed", 4, "[B, H, L, D//2]")
_check_rank(v_norms, "v_norms", 4, "[B, H, L, 1]")
_check_rank(centroids, "centroids", 1, "[16]")
_check_rank(rotation, "rotation", 2, "[D, D]")
if attn_mask is not None:
_check_rank(attn_mask, "attn_mask", 4, "[B, 1, L_Q, L_KV]")
if kv_len is not None and kv_len.dim() > 1:
raise RuntimeError(
f"kv_len must be a scalar or 1D tensor with one element, "
f"got {kv_len.dim()}D"
)

if not query.is_cuda:
raise RuntimeError("query must be a CUDA tensor")
expected_device = query.device
for name, tensor in (
("k_packed", k_packed),
("k_norms", k_norms),
("v_packed", v_packed),
("v_norms", v_norms),
("centroids", centroids),
("rotation", rotation),
):
_check_same_cuda_device(tensor, name, expected_device)
if attn_mask is not None:
_check_same_cuda_device(attn_mask, "attn_mask", expected_device)
if kv_len is not None:
_check_same_cuda_device(kv_len, "kv_len", expected_device)

if query.dtype != torch.bfloat16:
raise RuntimeError(f"query must be bfloat16, got {query.dtype}")
if query.dim() != 4:
raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D")
if k_packed.dim() != 4 or v_packed.dim() != 4:
raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]")
if k_packed.dtype != torch.uint8 or v_packed.dtype != torch.uint8:
raise RuntimeError("k_packed and v_packed must be uint8")
if k_packed.dtype != torch.uint8:
raise RuntimeError(f"k_packed must be uint8, got {k_packed.dtype}")
if v_packed.dtype != torch.uint8:
raise RuntimeError(f"v_packed must be uint8, got {v_packed.dtype}")
if k_norms.dtype not in (torch.float32, torch.bfloat16):
raise RuntimeError(
f"k_norms must be float32 or bfloat16, got {k_norms.dtype}"
)
if v_norms.dtype not in (torch.float32, torch.bfloat16):
raise RuntimeError(
f"v_norms must be float32 or bfloat16, got {v_norms.dtype}"
)
if not torch.is_floating_point(centroids):
raise RuntimeError(f"centroids must be floating point, got {centroids.dtype}")
if not torch.is_floating_point(rotation):
raise RuntimeError(f"rotation must be floating point, got {rotation.dtype}")
if attn_mask is not None and attn_mask.dtype != torch.bool:
raise RuntimeError(
f"attn_mask must be bool, got {attn_mask.dtype}. "
"Additive float masks are not supported."
)
if kv_len is not None and kv_len.dtype not in (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
):
raise RuntimeError(f"kv_len must have integer dtype, got {kv_len.dtype}")

B, H_Q, N_Q, D = query.shape
B_kp, H_KV, N_KV, HALF_D = k_packed.shape

if v_packed.shape != k_packed.shape:
raise RuntimeError(
f"v_packed shape mismatch: expected {tuple(k_packed.shape)} to match "
f"k_packed, got {tuple(v_packed.shape)}"
)
if B_kp != B:
raise RuntimeError(
f"Batch dim mismatch: query has B={B}, k_packed has B={B_kp}"
)
if H_KV == 0:
raise RuntimeError("k_packed head dimension must be greater than 0")
if H_Q % H_KV != 0:
raise RuntimeError(
f"H_Q must be a multiple of H_KV for GQA head mapping, "
Expand All @@ -647,34 +732,46 @@ def _validate_tq4_inputs(query, k_packed, v_packed):
f"HEAD_DIM must be a power of 2, got {D}. "
"Non-power-of-2 head dims are not supported."
)


def _validate_tq4_mask(attn_mask, B, N_Q, N_KV):
"""Validate attention mask for tq4_sdpa."""
if attn_mask is None:
return
if attn_mask.dtype != torch.bool:
expected_norm_shape = (B, H_KV, N_KV, 1)
if k_norms.shape != expected_norm_shape:
raise RuntimeError(
f"attn_mask must be bool, got {attn_mask.dtype}. "
"Additive float masks are not supported."
f"k_norms shape mismatch: expected {expected_norm_shape} to match "
f"k_packed layout, got {tuple(k_norms.shape)}"
)
if not attn_mask.is_cuda:
raise RuntimeError("attn_mask must be a CUDA tensor")
if attn_mask.shape[1] != 1:
if v_norms.shape != expected_norm_shape:
raise RuntimeError(
f"attn_mask head dimension must be 1 (broadcast over heads); "
f"per-head masks are not supported. "
f"Got attn_mask.shape={attn_mask.shape}"
f"v_norms shape mismatch: expected {expected_norm_shape} to match "
f"v_packed layout, got {tuple(v_norms.shape)}"
)
if (
attn_mask.shape[0] != B
or attn_mask.shape[2] != N_Q
or attn_mask.shape[3] != N_KV
):
if centroids.shape != (16,):
raise RuntimeError(
f"attn_mask shape mismatch: expected "
f"[B={B}, 1, L_Q={N_Q}, L_KV={N_KV}], "
f"got {attn_mask.shape}"
f"centroids must have shape (16,), got {tuple(centroids.shape)}"
)
if rotation.shape != (D, D):
raise RuntimeError(
f"rotation must have shape ({D}, {D}), got {tuple(rotation.shape)}"
)

if attn_mask is not None:
if attn_mask.shape[1] != 1:
raise RuntimeError(
f"attn_mask head dimension must be 1 (broadcast over heads); "
f"per-head masks are not supported. "
f"Got attn_mask.shape={attn_mask.shape}"
)
if (
attn_mask.shape[0] != B
or attn_mask.shape[2] != N_Q
or attn_mask.shape[3] != N_KV
):
raise RuntimeError(
f"attn_mask shape mismatch: expected "
f"[B={B}, 1, L_Q={N_Q}, L_KV={N_KV}], "
f"got {attn_mask.shape}"
)
if kv_len is not None and kv_len.numel() != 1:
raise RuntimeError(
f"kv_len must contain exactly one element, got {kv_len.numel()}"
)


Expand Down Expand Up @@ -736,13 +833,21 @@ def tq4_sdpa(
Returns:
[B, H_Q, L_Q, D] bf16 attention output
"""
_validate_tq4_inputs(query, k_packed, v_packed)
_validate_tq4_inputs(
query,
k_packed,
k_norms,
v_packed,
v_norms,
centroids,
rotation,
attn_mask,
kv_len,
)

B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale)
num_groups = H_Q // H_KV

Expand Down
Loading