-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathaddmm.py
More file actions
94 lines (78 loc) · 2.02 KB
/
addmm.py
File metadata and controls
94 lines (78 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
import ntops.kernels.mm as mm
def arrangement(
    input,
    x,
    y,
    beta,
    alpha,
    output,
    input_precision,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Arrange the tensors used by the ``addmm`` kernel.

    Tiling is delegated to ``mm.arrangement``, which is called twice with the
    same ``x`` / ``y`` operands and block sizes:

    1. with ``input`` placed in the output slot, so ``input`` receives the
       same arrangement that ``mm.arrangement`` gives its output argument;
    2. with ``output`` itself, which also supplies the arranged ``x`` and
       ``y``.

    ``beta``, ``alpha`` and ``input_precision`` are returned unchanged.

    Block sizes left as ``None`` fall back to ``mm.BLOCK_SIZE_M``,
    ``mm.BLOCK_SIZE_N`` and ``mm.BLOCK_SIZE_K`` respectively.

    Returns a 7-tuple ordered like the first seven parameters:
    ``(input, x, y, beta, alpha, output, input_precision)``.
    """
    block_size_m = mm.BLOCK_SIZE_M if block_size_m is None else block_size_m
    block_size_n = mm.BLOCK_SIZE_N if block_size_n is None else block_size_n
    block_size_k = mm.BLOCK_SIZE_K if block_size_k is None else block_size_k

    def _arrange_into(target):
        # Arrange the mm operands with ``target`` occupying the output slot.
        return mm.arrangement(
            x,
            y,
            target,
            input_precision,
            block_size_m=block_size_m,
            block_size_n=block_size_n,
            block_size_k=block_size_k,
        )

    # Only the arranged ``input`` is kept from this call; the operand
    # arrangements are taken from the call for ``output`` below.
    _, _, input_arranged, _ = _arrange_into(input)
    x_arranged, y_arranged, output_arranged, _ = _arrange_into(output)

    return (
        input_arranged,
        x_arranged,
        y_arranged,
        beta,
        alpha,
        output_arranged,
        input_precision,
    )
def application(input, x, y, beta, alpha, output, input_precision):
    # Kernel body: output = beta * input + alpha * mm(x, y).
    # NOTE(review): this function is presumably parsed and rewritten by
    # ninetoothed rather than run as ordinary Python; assigning to the
    # parameter ``output`` on the last line appears to be how the result is
    # stored. Confirm before restructuring, and keep comments (not a
    # docstring) so no extra statement enters the transformed body.
    # Zero-initialised float32 buffer for the product, independent of
    # ``output``'s dtype.
    mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
    # The return value is unused, so this relies on mm.application writing
    # the product into ``mm_output`` — TODO confirm against ntops.kernels.mm.
    mm.application(x, y, mm_output, input_precision)
    output = beta * input + alpha * mm_output
def premake(
    input_precision=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Return the ``(arrangement, application, tensors)`` triple for ``addmm``.

    ``arrangement`` is returned with the three block sizes bound via
    ``functools.partial``. ``tensors`` is ordered like ``application``'s
    parameters:

    ``input``, ``x``, ``y``, ``output``
        2-D tensors of element type ``dtype``.
    ``beta``, ``alpha``
        0-D float64 tensors.
    ``input_precision``
        0-D constexpr tensor whose value is ``input_precision``.

    Presumably consumed by ``ninetoothed.make`` — confirm against callers.
    """
    bound_arrangement = functools.partial(
        arrangement,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    def _matrix():
        # 2-D symbolic tensor with the requested element type.
        return Tensor(2, dtype=dtype)

    def _scalar():
        # 0-D float64 symbolic tensor (used for beta and alpha).
        return Tensor(0, dtype=ninetoothed.float64)

    # Keep construction in parameter order; the tuple below relies on it.
    input_, x, y = _matrix(), _matrix(), _matrix()
    beta, alpha = _scalar(), _scalar()
    output = _matrix()
    precision = Tensor(0, constexpr=True, value=input_precision)

    tensors = (input_, x, y, beta, alpha, output, precision)
    return bound_arrangement, application, tensors