-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathreference.py
More file actions
26 lines (21 loc) · 849 Bytes
/
reference.py
File metadata and controls
26 lines (21 loc) · 849 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from task import input_t, output_t
from utils import verbose_allclose
def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    """Build a seeded pair of float16 CUDA matrices for a matmul problem.

    Args:
        m: Row count of the first matrix.
        n: Column count of the second matrix.
        k: Shared inner dimension (columns of the first, rows of the second).
        seed: Seed for the CUDA random generator, so inputs are reproducible.

    Returns:
        A tuple ``(a, b)`` where ``a`` has shape ``(m, k)`` and ``b`` has
        shape ``(k, n)``, both filled via ``uniform_(0, 1)``.
    """
    rng = torch.Generator(device='cuda')
    rng.manual_seed(seed)

    def _uniform_matrix(rows: int, cols: int):
        # Allocate uninitialised storage, then fill it in place from the
        # shared generator; uniform_ returns the tensor it filled.
        matrix = torch.empty(rows, cols, device='cuda', dtype=torch.float16)
        return matrix.uniform_(0, 1, generator=rng)

    # Order matters: `a` must draw from the generator before `b` so that the
    # same seed always yields the same pair.
    a = _uniform_matrix(m, k)
    b = _uniform_matrix(k, n)
    return (a, b)
def ref_kernel(data: input_t) -> output_t:
    """Compute the reference result: the matrix product of the two inputs.

    Args:
        data: A two-element input that unpacks into the left and right
            operands of the product.

    Returns:
        The result of ``lhs @ rhs``.
    """
    lhs, rhs = data
    product = lhs @ rhs
    return product
def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]:
    """Compare a custom implementation's output against the reference result.

    Args:
        data: Input handed to ``ref_kernel`` to compute the expected result.
        output: Result produced by the implementation under test.

    Returns:
        ``(True, '')`` when ``verbose_allclose`` reports no mismatch reasons;
        otherwise ``(False, message)``, where ``message`` contains the first
        reported reason.
    """
    expected = ref_kernel(data)
    reasons = verbose_allclose(output, expected)
    if not reasons:
        return True, ''
    # TODO better processing of reasons
    # Only the first reason is surfaced to the caller; any others are dropped.
    return False, "mismatch found! custom implementation doesn't match reference.: " + reasons[0]