Skip to content

Commit 7f82e8f

Browse files
committed
feat(demo): add GPT-2 interactive inference with WarpForth attention kernel
1 parent cb95a17 commit 7f82e8f

6 files changed

Lines changed: 511 additions & 0 deletions

File tree

demo/README.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# GPT-2 Demo
2+
3+
Text generation with GPT-2-124M using a WarpForth-compiled attention kernel. The stock Hugging Face model is loaded normally, then `eager_attention_forward` is monkey-patched to route scaled dot-product attention through a WarpForth kernel compiled to PTX. PyCUDA shares PyTorch's CUDA context via `autoprimaryctx`, so device pointers pass directly between the two — no copies, no CPU roundtrips.
4+
5+
## Prerequisites
6+
7+
- WarpForth built locally (`cmake --build build`)
8+
- A Vast.ai GPU instance with a PyTorch image (e.g. `pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime`)
9+
10+
## Step 1: Compile the Kernel (Local)
11+
12+
```bash
13+
./build/bin/warpforthc demo/attention.forth > demo/attention.ptx
14+
```
15+
16+
A pre-compiled `attention.ptx` is included in this directory.
17+
18+
## Step 2: Upload to GPU Instance
19+
20+
```bash
21+
scp demo/attention.ptx demo/warpforth.py demo/gpt2_generate.py root@HOST:/workspace
22+
```
23+
24+
## Step 3: Install Dependencies (Remote)
25+
26+
```bash
27+
pip install pycuda transformers
28+
```
29+
30+
## Step 4: Generate Text (Remote)
31+
32+
```bash
33+
python /workspace/gpt2_generate.py --ptx /workspace/attention.ptx --prompt "The meaning of life is"
34+
```
35+
36+
| Flag | Default | Description |
37+
|------|---------|-------------|
38+
| `--ptx` | (required) | Path to compiled `attention.ptx` |
39+
| `--prompt` | `"The meaning of life is"` | Input text prompt |
40+
| `--max-tokens` | `100` | Maximum new tokens to generate |
41+
42+
## Limitations
43+
44+
- **Batch size 1** — the kernel processes one sequence at a time
45+
- **No KV cache** — all positions are recomputed each step (`use_cache=False`)
46+
- **Max sequence length 1024** — limited by shared memory allocation
47+
- **12 kernel launches per layer** — one per attention head
48+
49+
## Files
50+
51+
| File | Description |
52+
|------|-------------|
53+
| `attention.forth` | Attention kernel source (f32 global, f64 shared) |
54+
| `attention.ptx` | Pre-compiled PTX |
55+
| `warpforth.py` | PyCUDA wrapper for loading and launching the kernel |
56+
| `gpt2_generate.py` | Loads GPT-2, patches attention, generates text |

demo/attention.forth

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
\ GPT-2 attention kernel with f32 global memory, f64 shared memory for softmax.
2+
\ Adapted from test/Pipeline/attention.forth — 4 lines changed for f32 access.
3+
\
4+
\ Q/K/V/O are f32 arrays passed as raw byte buffers (i64[]).
5+
\ Global loads/stores use F32@/F32! with 4* byte addressing (f32 = 4 bytes).
6+
\ Shared memory stays f64 for softmax precision, using SF@/SF! with CELLS.
\
\ Grid mapping: BID-X selects the query row, TID-X the key position t;
\ in the final phase each thread strides over output columns by BDIM-X.
7+
8+
\! kernel attention
9+
\! param Q i64[32768]
10+
\! param K i64[32768]
11+
\! param V i64[32768]
12+
\! param O i64[32768]
13+
\! param SEQ_LEN i64
14+
\! param HEAD_DIM i64
15+
\! shared SCORES f64[1024]
16+
\! shared SCRATCH f64[1024]
17+
18+
\ row = BID-X, t = TID-X
\ Data stack after these two words: ( row t )
19+
BID-X
20+
TID-X
21+
22+
\ --- Dot product: Q[row,:] . K[t,:] ---
\ ( row t acc ) — acc accumulates in f64; scaled by 1/sqrt(HEAD_DIM) after.
23+
0.0
24+
HEAD_DIM 0 DO
\ Q[row*HEAD_DIM + I]: 2 PICK = row; "4 *" converts f32 index to bytes.
25+
2 PICK HEAD_DIM * I + 4 * Q + F32@
\ K[t*HEAD_DIM + I]: stack is now ( row t acc q ), so 2 PICK = t.
26+
2 PICK HEAD_DIM * I + 4 * K + F32@
27+
F* F+
28+
LOOP
29+
HEAD_DIM S>F FSQRT F/
30+
31+
\ --- Causal mask: if t > row, score = -1.0e30 (stand-in for -inf) ---
\ ( row t score ): OVER = t, 3 PICK = row.
32+
OVER 3 PICK >
33+
IF DROP -1.0e30 THEN
34+
35+
\ --- Store score to shared memory: SCORES[t] = score; leaves ( row t ) ---
36+
OVER CELLS SCORES + SF!
37+
BARRIER
38+
39+
\ --- Softmax: max reduction (thread 0) ---
\ Serial scan by thread 0 only; result published to all via SCRATCH[0].
40+
TID-X 0= IF
41+
0 CELLS SCORES + SF@
42+
SEQ_LEN 1 DO I CELLS SCORES + SF@ FMAX LOOP
43+
0 CELLS SCRATCH + SF!
44+
THEN
45+
BARRIER
46+
47+
\ --- Softmax: exp(score - max) ---
\ Every thread rewrites its own slot: SCORES[t] = exp(SCORES[t] - max).
\ DUP/OVER here both reference t; stack stays ( row t ).
48+
DUP CELLS SCORES + SF@
49+
0 CELLS SCRATCH + SF@
50+
F- FEXP
51+
OVER CELLS SCORES + SF!
52+
BARRIER
53+
54+
\ --- Softmax: sum reduction (thread 0) ---
55+
TID-X 0= IF
56+
0.0
57+
SEQ_LEN 0 DO I CELLS SCORES + SF@ F+ LOOP
58+
0 CELLS SCRATCH + SF!
59+
THEN
60+
BARRIER
61+
62+
\ --- Softmax: normalize ---
\ SCORES[t] /= sum; stack still ( row t ).
63+
DUP CELLS SCORES + SF@
64+
0 CELLS SCRATCH + SF@
65+
F/
66+
OVER CELLS SCORES + SF!
67+
BARRIER
68+
69+
\ --- V accumulation: O[row,col] = sum_j SCORES[j] * V[j*HD + col] ---
70+
\ Stride over head_dim columns: col = t, t+BDIM-X, t+2*BDIM-X, ...
\ ( row t ) -> ( row t col ); col starts at t.
71+
DUP BEGIN DUP HEAD_DIM < WHILE
72+
0.0
73+
SEQ_LEN 0 DO
74+
I CELLS SCORES + SF@
\ V[I*HEAD_DIM + col]: 3 PICK = col from ( row t col acc ).
75+
I HEAD_DIM * 3 PICK + 4 * V + F32@
76+
F* F+
77+
LOOP
\ O[row*HEAD_DIM + col] = acc (narrowed to f32): OVER = col, 4 PICK = row.
78+
OVER 4 PICK HEAD_DIM * + 4 * O + F32!
79+
BDIM-X +
80+
REPEAT
\ Clear the remaining ( row t col ).
81+
DROP DROP DROP

demo/attention.ptx

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
//
2+
// Generated by LLVM NVPTX Back-End
3+
//
// NOTE(review): generated artifact — do not hand-edit. Regenerate with:
//   ./build/bin/warpforthc demo/attention.forth > demo/attention.ptx
// Param order mirrors the Forth kernel: Q, K, V, O, SEQ_LEN, HEAD_DIM.
4+
5+
.version 6.0
6+
.target sm_70
7+
.address_size 64
8+
9+
// .globl attention
10+
// __wg_attention_0 has been demoted
11+
// __wg_attention_1_$_0 has been demoted
12+
13+
.visible .entry attention(
14+
.param .u64 .ptr .align 1 attention_param_0,
15+
.param .u64 .ptr .align 1 attention_param_1,
16+
.param .u64 .ptr .align 1 attention_param_2,
17+
.param .u64 .ptr .align 1 attention_param_3,
18+
.param .u64 attention_param_4,
19+
.param .u64 attention_param_5
20+
)
21+
{
22+
.reg .pred %p<17>;
23+
.reg .b32 %r<21>;
24+
.reg .f32 %f<7>;
25+
.reg .b64 %rd<76>;
26+
.reg .f64 %fd<63>;
27+
// demoted variable
// SCORES: 1024 x f64 = 8192 bytes of shared memory.
28+
.shared .align 8 .b8 __wg_attention_0[8192];
29+
// demoted variable
// SCRATCH shrunk to a single f64: only element 0 is ever accessed.
30+
.shared .align 8 .f64 __wg_attention_1_$_0;
31+
ld.param.u64 %rd41, [attention_param_5];
32+
ld.param.u64 %rd40, [attention_param_4];
33+
ld.param.u64 %rd39, [attention_param_3];
34+
ld.param.u64 %rd38, [attention_param_2];
35+
ld.param.u64 %rd42, [attention_param_0];
36+
mov.u32 %r1, %ctaid.x;
37+
ld.param.u64 %rd43, [attention_param_1];
38+
cvt.u64.u32 %rd44, %r1;
39+
mov.u32 %r5, %tid.x;
40+
cvt.u64.u32 %rd72, %r5;
41+
mul.lo.s64 %rd2, %rd41, %rd44;
42+
mul.lo.s64 %rd45, %rd41, %rd72;
43+
neg.s64 %rd66, %rd41;
44+
shl.b64 %rd46, %rd45, 2;
45+
add.s64 %rd65, %rd43, %rd46;
46+
shl.b64 %rd47, %rd2, 2;
47+
add.s64 %rd64, %rd42, %rd47;
48+
mov.f64 %fd58, 0d0000000000000000;
// Dot-product loop: Q[row,:] . K[t,:], f32 loads widened to f64.
49+
$L__BB0_1:
50+
ld.f32 %f2, [%rd64];
51+
cvt.f64.f32 %fd16, %f2;
52+
ld.f32 %f3, [%rd65];
53+
cvt.f64.f32 %fd17, %f3;
54+
mul.rn.f64 %fd18, %fd16, %fd17;
55+
add.rn.f64 %fd58, %fd58, %fd18;
56+
add.s64 %rd9, %rd66, 1;
57+
xor.b64 %rd48, %rd9, %rd66;
58+
setp.gt.s64 %p1, %rd48, -1;
59+
add.s64 %rd65, %rd65, 4;
60+
add.s64 %rd64, %rd64, 4;
61+
mov.u64 %rd66, %rd9;
62+
@%p1 bra $L__BB0_1;
63+
cvt.u32.u64 %r6, %rd72;
64+
cvt.rn.f64.s64 %fd19, %rd41;
65+
sqrt.rn.f64 %fd20, %fd19;
66+
div.rn.f64 %fd21, %fd58, %fd20;
67+
setp.gt.u32 %p2, %r6, %r1;
68+
shl.b64 %rd49, %rd72, 3;
69+
mov.u64 %rd50, __wg_attention_0;
70+
add.s64 %rd12, %rd49, %rd50;
// Causal mask: 0dC6293E5939A08CEA is -1.0e30 (the Forth stand-in for -inf).
71+
selp.f64 %fd22, 0dC6293E5939A08CEA, %fd21, %p2;
72+
st.shared.f64 [%rd12], %fd22;
73+
bar.sync 0;
74+
setp.ne.s32 %p3, %r6, 0;
75+
@%p3 bra $L__BB0_6;
76+
ld.shared.f64 %fd60, [__wg_attention_0];
77+
mov.u64 %rd52, __wg_attention_0;
78+
add.s64 %rd68, %rd52, 8;
79+
sub.s64 %rd67, 1, %rd40;
// Thread 0 only: serial max reduction over SCORES[1..SEQ_LEN-1].
80+
$L__BB0_4:
81+
ld.shared.f64 %fd23, [%rd68];
82+
mov.b64 %rd53, %fd23;
83+
setp.nan.f64 %p4, %fd60, %fd23;
84+
max.f64 %fd24, %fd60, %fd23;
85+
selp.f64 %fd25, 0d7FF8000000000000, %fd24, %p4;
86+
mov.b64 %rd54, %fd60;
87+
setp.eq.s64 %p5, %rd54, 0;
88+
selp.f64 %fd26, %fd60, %fd25, %p5;
89+
setp.eq.s64 %p6, %rd53, 0;
90+
selp.f64 %fd27, %fd23, %fd26, %p6;
91+
setp.eq.f64 %p7, %fd25, 0d0000000000000000;
92+
selp.f64 %fd60, %fd27, %fd25, %p7;
93+
add.s64 %rd17, %rd67, 1;
94+
xor.b64 %rd55, %rd17, %rd67;
95+
setp.gt.s64 %p8, %rd55, -1;
96+
add.s64 %rd68, %rd68, 8;
97+
mov.u64 %rd67, %rd17;
98+
@%p8 bra $L__BB0_4;
99+
st.shared.f64 [__wg_attention_1_$_0], %fd60;
100+
$L__BB0_6:
101+
bar.sync 0;
102+
ld.shared.f64 %fd28, [%rd12];
103+
ld.shared.f64 %fd29, [__wg_attention_1_$_0];
104+
sub.rn.f64 %fd4, %fd28, %fd29;
// exp(score - max), inlined as 2^(x*log2(e)) with a polynomial expansion
// (0d3FF71547652B82FE = log2(e)); overflow/underflow handled below.
105+
fma.rn.f64 %fd30, %fd4, 0d3FF71547652B82FE, 0d4338000000000000;
106+
{
107+
.reg .b32 %temp;
108+
mov.b64 {%r2, %temp}, %fd30;
109+
}
110+
mov.f64 %fd31, 0dC338000000000000;
111+
add.rn.f64 %fd32, %fd30, %fd31;
112+
fma.rn.f64 %fd33, %fd32, 0dBFE62E42FEFA39EF, %fd4;
113+
fma.rn.f64 %fd34, %fd32, 0dBC7ABC9E3B39803F, %fd33;
114+
fma.rn.f64 %fd35, %fd34, 0d3E5ADE1569CE2BDF, 0d3E928AF3FCA213EA;
115+
fma.rn.f64 %fd36, %fd35, %fd34, 0d3EC71DEE62401315;
116+
fma.rn.f64 %fd37, %fd36, %fd34, 0d3EFA01997C89EB71;
117+
fma.rn.f64 %fd38, %fd37, %fd34, 0d3F2A01A014761F65;
118+
fma.rn.f64 %fd39, %fd38, %fd34, 0d3F56C16C1852B7AF;
119+
fma.rn.f64 %fd40, %fd39, %fd34, 0d3F81111111122322;
120+
fma.rn.f64 %fd41, %fd40, %fd34, 0d3FA55555555502A1;
121+
fma.rn.f64 %fd42, %fd41, %fd34, 0d3FC5555555555511;
122+
fma.rn.f64 %fd43, %fd42, %fd34, 0d3FE000000000000B;
123+
fma.rn.f64 %fd44, %fd43, %fd34, 0d3FF0000000000000;
124+
fma.rn.f64 %fd45, %fd44, %fd34, 0d3FF0000000000000;
125+
{
126+
.reg .b32 %temp;
127+
mov.b64 {%r3, %temp}, %fd45;
128+
}
129+
{
130+
.reg .b32 %temp;
131+
mov.b64 {%temp, %r4}, %fd45;
132+
}
133+
shl.b32 %r7, %r2, 20;
134+
add.s32 %r8, %r4, %r7;
135+
mov.b64 %fd59, {%r3, %r8};
136+
{
137+
.reg .b32 %temp;
138+
mov.b64 {%temp, %r9}, %fd4;
139+
}
140+
mov.b32 %f4, %r9;
141+
abs.f32 %f1, %f4;
142+
setp.lt.f32 %p9, %f1, 0f4086232B;
143+
@%p9 bra $L__BB0_9;
144+
setp.lt.f64 %p10, %fd4, 0d0000000000000000;
145+
add.rn.f64 %fd46, %fd4, 0d7FF0000000000000;
146+
selp.f64 %fd59, 0d0000000000000000, %fd46, %p10;
147+
setp.geu.f32 %p11, %f1, 0f40874800;
148+
@%p11 bra $L__BB0_9;
149+
shr.u32 %r10, %r2, 31;
150+
add.s32 %r11, %r2, %r10;
151+
shr.s32 %r12, %r11, 1;
152+
shl.b32 %r13, %r12, 20;
153+
add.s32 %r14, %r4, %r13;
154+
mov.b64 %fd47, {%r3, %r14};
155+
sub.s32 %r15, %r2, %r12;
156+
shl.b32 %r16, %r15, 20;
157+
add.s32 %r17, %r16, 1072693248;
158+
mov.b32 %r18, 0;
159+
mov.b64 %fd48, {%r18, %r17};
160+
mul.rn.f64 %fd59, %fd48, %fd47;
161+
$L__BB0_9:
162+
st.shared.f64 [%rd12], %fd59;
163+
bar.sync 0;
164+
@%p3 bra $L__BB0_13;
165+
neg.s64 %rd69, %rd40;
166+
mov.f64 %fd61, 0d0000000000000000;
167+
mov.u64 %rd70, __wg_attention_0;
// Thread 0 only: serial sum reduction over SCORES[0..SEQ_LEN-1].
168+
$L__BB0_11:
169+
ld.shared.f64 %fd50, [%rd70];
170+
add.rn.f64 %fd61, %fd61, %fd50;
171+
add.s64 %rd26, %rd69, 1;
172+
xor.b64 %rd57, %rd26, %rd69;
173+
setp.gt.s64 %p13, %rd57, -1;
174+
add.s64 %rd70, %rd70, 8;
175+
mov.u64 %rd69, %rd26;
176+
@%p13 bra $L__BB0_11;
177+
st.shared.f64 [__wg_attention_1_$_0], %fd61;
178+
$L__BB0_13:
179+
bar.sync 0;
// Normalize: SCORES[t] /= sum.
180+
ld.shared.f64 %fd51, [%rd12];
181+
ld.shared.f64 %fd52, [__wg_attention_1_$_0];
182+
div.rn.f64 %fd53, %fd51, %fd52;
183+
st.shared.f64 [%rd12], %fd53;
184+
bar.sync 0;
185+
setp.le.s64 %p14, %rd41, %rd72;
186+
@%p14 bra $L__BB0_18;
187+
mov.u32 %r20, %ntid.x;
188+
cvt.u64.u32 %rd19, %r20;
189+
shl.b64 %rd58, %rd72, 2;
190+
add.s64 %rd71, %rd38, %rd58;
191+
mul.wide.u32 %rd21, %r20, 4;
192+
shl.b64 %rd22, %rd41, 2;
193+
neg.s64 %rd23, %rd40;
// Output loop: O[row,col] = sum_j SCORES[j]*V[j*HEAD_DIM+col],
// col strided by %ntid.x; result narrowed back to f32 on store.
194+
$L__BB0_15:
195+
mov.f64 %fd62, 0d0000000000000000;
196+
mov.u64 %rd73, %rd23;
197+
mov.u64 %rd74, %rd71;
198+
mov.u64 %rd75, %rd50;
199+
$L__BB0_16:
200+
ld.shared.f64 %fd55, [%rd75];
201+
ld.f32 %f5, [%rd74];
202+
cvt.f64.f32 %fd56, %f5;
203+
mul.rn.f64 %fd57, %fd55, %fd56;
204+
add.rn.f64 %fd62, %fd62, %fd57;
205+
add.s64 %rd33, %rd73, 1;
206+
xor.b64 %rd60, %rd33, %rd73;
207+
setp.gt.s64 %p15, %rd60, -1;
208+
add.s64 %rd75, %rd75, 8;
209+
add.s64 %rd74, %rd74, %rd22;
210+
mov.u64 %rd73, %rd33;
211+
@%p15 bra $L__BB0_16;
212+
add.s64 %rd61, %rd72, %rd2;
213+
shl.b64 %rd62, %rd61, 2;
214+
add.s64 %rd63, %rd62, %rd39;
215+
cvt.rn.f32.f64 %f6, %fd62;
216+
st.f32 [%rd63], %f6;
217+
add.s64 %rd72, %rd72, %rd19;
218+
add.s64 %rd71, %rd71, %rd21;
219+
setp.lt.s64 %p16, %rd72, %rd41;
220+
@%p16 bra $L__BB0_15;
221+
$L__BB0_18:
222+
ret;
223+
224+
}

0 commit comments

Comments
 (0)