Skip to content

Commit f3d7997

Browse files
committed
test(gpu-test): add flash attention end-to-end GPU test
1 parent 021a186 commit f3d7997

3 files changed

Lines changed: 117 additions & 0 deletions

File tree

gpu_test/test_kernels.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING
66

7+
import numpy as np
78
import pytest
89

910
if TYPE_CHECKING:
@@ -583,3 +584,116 @@ def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None:
583584
forth_source=("\\! kernel main\n\\! param DATA i64[256]\n7.9 F>S\n0 CELLS DATA + !"),
584585
)
585586
assert result[0] == 7
587+
588+
589+
# --- Attention ---
590+
591+
592+
def test_naive_attention_f64(kernel_runner: KernelRunner) -> None:
593+
"""Naive scaled dot-product attention with causal mask.
594+
595+
O = softmax(Q @ K^T / sqrt(d_k)) @ V, seq_len=4, head_dim=4.
596+
One block per query row, one thread per key position.
597+
"""
598+
seq_len, head_dim = 4, 4
599+
600+
q = np.array(
601+
[
602+
[1.0, 0.0, 1.0, 0.0],
603+
[0.0, 1.0, 0.0, 1.0],
604+
[1.0, 1.0, 0.0, 0.0],
605+
[0.0, 0.0, 1.0, 1.0],
606+
]
607+
)
608+
k = np.array(
609+
[
610+
[1.0, 0.0, 0.0, 1.0],
611+
[0.0, 1.0, 1.0, 0.0],
612+
[1.0, 1.0, 0.0, 0.0],
613+
[0.0, 0.0, 1.0, 1.0],
614+
]
615+
)
616+
v = np.array(
617+
[
618+
[1.0, 2.0, 3.0, 4.0],
619+
[5.0, 6.0, 7.0, 8.0],
620+
[9.0, 10.0, 11.0, 12.0],
621+
[13.0, 14.0, 15.0, 16.0],
622+
]
623+
)
624+
625+
# Reference: scaled dot-product attention with causal mask
626+
scores = q @ k.T / np.sqrt(head_dim)
627+
causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
628+
scores[causal_mask] = -1e30
629+
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
630+
attn = exp_scores / exp_scores.sum(axis=1, keepdims=True)
631+
expected = (attn @ v).flatten().tolist()
632+
633+
result = kernel_runner.run(
634+
forth_source=(
635+
"\\! kernel attention\n"
636+
"\\! param Q f64[16]\n"
637+
"\\! param K f64[16]\n"
638+
"\\! param V f64[16]\n"
639+
"\\! param O f64[16]\n"
640+
"\\! param SEQ_LEN i64\n"
641+
"\\! param HEAD_DIM i64\n"
642+
"\\! shared SCORES f64[4]\n"
643+
"\\! shared SCRATCH f64[4]\n"
644+
"BID-X\n"
645+
"TID-X\n"
646+
"0.0\n"
647+
"HEAD_DIM 0 DO\n"
648+
" 2 PICK HEAD_DIM * I + CELLS Q + F@\n"
649+
" 2 PICK HEAD_DIM * I + CELLS K + F@\n"
650+
" F* F+\n"
651+
"LOOP\n"
652+
"HEAD_DIM S>F FSQRT F/\n"
653+
"OVER 3 PICK >\n"
654+
"IF DROP -1.0e30 THEN\n"
655+
"OVER CELLS SCORES + SF!\n"
656+
"BARRIER\n"
657+
"TID-X 0= IF\n"
658+
" 0 CELLS SCORES + SF@\n"
659+
" SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP\n"
660+
" 0 CELLS SCRATCH + SF!\n"
661+
"THEN\n"
662+
"BARRIER\n"
663+
"DUP CELLS SCORES + SF@\n"
664+
"0 CELLS SCRATCH + SF@\n"
665+
"F- FEXP\n"
666+
"OVER CELLS SCORES + SF!\n"
667+
"BARRIER\n"
668+
"TID-X 0= IF\n"
669+
" 0.0\n"
670+
" SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP\n"
671+
" 0 CELLS SCRATCH + SF!\n"
672+
"THEN\n"
673+
"BARRIER\n"
674+
"DUP CELLS SCORES + SF@\n"
675+
"0 CELLS SCRATCH + SF@\n"
676+
"F/\n"
677+
"OVER CELLS SCORES + SF!\n"
678+
"BARRIER\n"
679+
"0.0\n"
680+
"SEQ_LEN 0 DO\n"
681+
" I CELLS SCORES + SF@\n"
682+
" I HEAD_DIM * 3 PICK + CELLS V + F@\n"
683+
" F* F+\n"
684+
"LOOP\n"
685+
"ROT HEAD_DIM * ROT + CELLS O + F!\n"
686+
),
687+
params={
688+
"Q": q.flatten().tolist(),
689+
"K": k.flatten().tolist(),
690+
"V": v.flatten().tolist(),
691+
"SEQ_LEN": seq_len,
692+
"HEAD_DIM": head_dim,
693+
},
694+
grid=(seq_len, 1, 1),
695+
block=(seq_len, 1, 1),
696+
output_param=3,
697+
output_count=seq_len * head_dim,
698+
)
699+
assert result == [pytest.approx(v) for v in expected]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.0"
44
requires-python = ">=3.11"
55
dependencies = [
66
"lit>=18.1.0",
7+
"numpy",
78
"pytest",
89
"vastai-sdk",
910
]

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)