|
4 | 4 |
|
5 | 5 | from typing import TYPE_CHECKING |
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | import pytest |
8 | 9 |
|
9 | 10 | if TYPE_CHECKING: |
@@ -583,3 +584,116 @@ def test_float_to_int_conversion(kernel_runner: KernelRunner) -> None: |
583 | 584 | forth_source=("\\! kernel main\n\\! param DATA i64[256]\n7.9 F>S\n0 CELLS DATA + !"), |
584 | 585 | ) |
585 | 586 | assert result[0] == 7 |
| 587 | + |
| 588 | + |
| 589 | +# --- Attention --- |
| 590 | + |
| 591 | + |
def test_naive_attention_f64(kernel_runner: KernelRunner) -> None:
    """Naive scaled dot-product attention with a causal mask, in f64.

    Computes O = softmax(Q @ K^T / sqrt(d_k)) @ V for seq_len=4,
    head_dim=4 and compares against a NumPy reference.

    Launch shape: one block per query row (BID-X = query index q), one
    thread per key position (TID-X = key index t).  In the final stage
    thread t also writes output column t, which covers every output
    element only because seq_len == head_dim here — keep them equal if
    the fixture sizes ever change (NOTE(review): assumption from the
    single per-thread store at the end of the kernel).
    """
    seq_len, head_dim = 4, 4

    q = np.array(
        [
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
        ]
    )
    k = np.array(
        [
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
        ]
    )
    v = np.array(
        [
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ]
    )

    # Reference: scaled dot-product attention with causal mask.
    # Masked positions get a large negative score so they vanish after
    # softmax; max-subtraction matches the kernel's stabilization step.
    scores = q @ k.T / np.sqrt(head_dim)
    causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores[causal_mask] = -1e30
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    attn = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    expected = (attn @ v).flatten().tolist()

    result = kernel_runner.run(
        forth_source=(
            "\\! kernel attention\n"
            "\\! param Q f64[16]\n"
            "\\! param K f64[16]\n"
            "\\! param V f64[16]\n"
            "\\! param O f64[16]\n"
            "\\! param SEQ_LEN i64\n"
            "\\! param HEAD_DIM i64\n"
            "\\! shared SCORES f64[4]\n"
            "\\! shared SCRATCH f64[4]\n"
            # Stack: q t  (query row from block id, key position from
            # thread id).
            "BID-X\n"
            "TID-X\n"
            # Stage 1: raw score = dot(Q[q], K[t]) / sqrt(head_dim).
            "0.0\n"
            "HEAD_DIM 0 DO\n"
            " 2 PICK HEAD_DIM * I + CELLS Q + F@\n"
            " 2 PICK HEAD_DIM * I + CELLS K + F@\n"
            " F* F+\n"
            "LOOP\n"
            "HEAD_DIM S>F FSQRT F/\n"
            # Causal mask: if t > q, replace the score with -1e30.
            "OVER 3 PICK >\n"
            "IF DROP -1.0e30 THEN\n"
            "OVER CELLS SCORES + SF!\n"
            "BARRIER\n"
            # Stage 2 (thread 0 only): row max into SCRATCH[0] for a
            # numerically stable softmax.
            "TID-X 0= IF\n"
            " 0 CELLS SCORES + SF@\n"
            " SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP\n"
            " 0 CELLS SCRATCH + SF!\n"
            "THEN\n"
            "BARRIER\n"
            # Stage 3: SCORES[t] = exp(score - max).
            "DUP CELLS SCORES + SF@\n"
            "0 CELLS SCRATCH + SF@\n"
            "F- FEXP\n"
            "OVER CELLS SCORES + SF!\n"
            "BARRIER\n"
            # Stage 4 (thread 0 only): row sum into SCRATCH[0].
            "TID-X 0= IF\n"
            " 0.0\n"
            " SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP\n"
            " 0 CELLS SCRATCH + SF!\n"
            "THEN\n"
            "BARRIER\n"
            # Stage 5: normalize — SCORES[t] is now the attention weight.
            "DUP CELLS SCORES + SF@\n"
            "0 CELLS SCRATCH + SF@\n"
            "F/\n"
            "OVER CELLS SCORES + SF!\n"
            "BARRIER\n"
            # Stage 6: O[q][t] = sum_i attn[i] * V[i][t] (thread t owns
            # output column t).
            "0.0\n"
            "SEQ_LEN 0 DO\n"
            " I CELLS SCORES + SF@\n"
            " I HEAD_DIM * 3 PICK + CELLS V + F@\n"
            " F* F+\n"
            "LOOP\n"
            "ROT HEAD_DIM * ROT + CELLS O + F!\n"
        ),
        params={
            "Q": q.flatten().tolist(),
            "K": k.flatten().tolist(),
            "V": v.flatten().tolist(),
            "SEQ_LEN": seq_len,
            "HEAD_DIM": head_dim,
        },
        grid=(seq_len, 1, 1),
        block=(seq_len, 1, 1),
        output_param=3,
        output_count=seq_len * head_dim,
    )
    # approx() on the whole list compares elementwise with the same
    # default tolerances as per-element approx, and avoids shadowing
    # the value matrix `v` with a comprehension variable.
    assert result == pytest.approx(expected)
0 commit comments