forked from gpu-mode/reference-kernels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreference.py
More file actions
62 lines (49 loc) · 2.2 KB
/
reference.py
File metadata and controls
62 lines (49 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Causal depthwise Conv1D — reference implementation.
Used in SSM-based LLM architectures such as Mamba, where a short causal
(left-padded) depthwise convolution mixes local context along the sequence
dimension independently per channel, before the selective state-space step.
"Causal" means output position t depends only on input positions <= t,
enforced by padding W-1 zeros on the left and no padding on the right.
This module provides a pure-PyTorch reference against which optimized
Triton/CUDA kernels are verified.
Shapes:
x : (B, D, S) — batch, channels (model dim), sequence length
weight : (D, W) — one filter of width W per channel (depthwise)
bias : (D,) — per-channel bias
output : (B, D, S) — same shape as input
"""
import torch
import torch.nn.functional as F
from task import input_t, output_t
from utils import make_match_reference, DeterministicContext
def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t:
    """Create a reproducible random (x, weight, bias) triple on the GPU.

    Args:
        B: batch size.
        D: number of channels (model dimension).
        S: sequence length.
        W: depthwise filter width.
        seed: RNG seed; identical seeds yield identical tensors.

    Returns:
        (x, weight, bias) contiguous fp32 CUDA tensors of shapes
        (B, D, S), (D, W), and (D,) respectively.
    """
    rng = torch.Generator(device="cuda")
    rng.manual_seed(seed)

    def _draw(*shape):
        # All three tensors share one seeded generator so the whole
        # triple is reproducible from a single seed.
        return torch.randn(
            *shape, dtype=torch.float32, device="cuda", generator=rng
        ).contiguous()

    return _draw(B, D, S), _draw(D, W), _draw(D)
def ref_kernel(data: input_t) -> output_t:
    """
    Pure-PyTorch reference for the causal depthwise Conv1D.

    Left-pads the sequence with W-1 zeros so the output at position t
    depends only on inputs at positions <= t, then applies a grouped
    conv1d with groups equal to the channel count so every channel is
    filtered by its own width-W kernel (depthwise).
    """
    with DeterministicContext():
        x, weight, bias = data
        kernel_width = weight.shape[1]
        # Causal padding: (left, right) = (W-1, 0) along the sequence axis.
        padded = F.pad(x, (kernel_width - 1, 0))
        # weight[:, None, :] has shape [D, 1, W]; with groups=D each of the
        # D channels is convolved with exactly one filter.
        return F.conv1d(
            padded,
            weight[:, None, :],
            bias=bias,
            groups=x.shape[1],
        )
# Harness entry point: compares a candidate kernel's output against
# ref_kernel elementwise. Tolerances are loose for fp32 (rtol=atol=1e-2) —
# presumably to absorb float reassociation in optimized Triton/CUDA
# kernels; confirm against the harness's expectations.
check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)