-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmatmul_opt_sme2.s
More file actions
151 lines (114 loc) · 5.74 KB
/
matmul_opt_sme2.s
File metadata and controls
151 lines (114 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// ---------------------------------------------------------------------------
// _sgemm_direct_sme2_2VLx2VL -- FP32 matrix multiply C = A * B using SME2
// outer products (FMOPA) accumulated in the four 32-bit ZA tiles.
//
// Each Loop_M x Loop_N step produces a (2*SVLs) x (2*SVLs) block of C, where
// SVLs = number of FP32 lanes in a streaming vector (cntw after smstart):
//     ZA0 = block rows [0,SVLs)      x block cols [0,SVLs)
//     ZA1 = block rows [0,SVLs)      x block cols [SVLs,2*SVLs)
//     ZA2 = block rows [SVLs,2*SVLs) x block cols [0,SVLs)
//     ZA3 = block rows [SVLs,2*SVLs) x block cols [SVLs,2*SVLs)
// Out-of-range rows/columns are masked with predicates derived from M and N.
//
// Arguments (AAPCS64):
//   x0 = M, x1 = K, x2 = N  (64-bit element counts)
//   x3 = matLeft      A, read in a panel layout: each group of SVLs rows is
//                     stored as K consecutive vectors of SVLs floats (vector k
//                     holds column k of those rows); consecutive row groups
//                     are SVLs*K floats apart.
//   x4 = matRight_mod B, read as K rows of N floats (row stride N).
//   x5 = matResult    C, written as M rows of N floats (row stride N).
//   NOTE(review): the register comment names x4 "matRight_mod", but the code
//   reads the panelized layout from x3 and a plain row-major layout from x4.
//   Confirm which buffer the caller pre-arranges.
//
// Preconditions / assumptions:
//   * Call in non-streaming mode with PSTATE.ZA off. The final smstop turns
//     both SM and ZA off unconditionally.
//   * The unrolled K loop loads A with an all-true predicate (pn10), so whole
//     SVLs-float vectors of A are read even for a partial last row group.
//     A is presumably padded to a multiple of SVLs rows (TODO confirm).
//   * NOTE(review): ZA is enabled without checking TPIDR2_EL0 or committing
//     a lazy ZA save. Callers must not have a pending lazy save.
//
// Register map:
//   x6  = SVLs, then SVLb, then SVLb-8 (w6 = store-loop bound)
//   x7  = a_ptr                 x8  = end of current A row group
//   x15 = x8 - SVLb (exit condition for the unrolled K loop)
//   x9  = c_base                x10 = c_ptr0
//   x11 = B + N floats (exit condition for the N loop)
//   x12 = M loop counter        x13 = ZA byte-slice index (4 per tile row)
//   x14 = 0 (psel index)        x16 = b_base    x17 = b_ptr
//   x19 = (SVLs+1)*N            x22 = SVLs*K    x23 = SVLs*N
//
// Clobbers: x3, x5-x17, NZCV, and all SVE Z/P state. Changing PSTATE.SM zeroes
// the Z and P registers. Only the low 64 bits of v8-v15 (d8-d15) are
// preserved, as AAPCS64 requires.
// Preserves: x19-x24, d8-d15 (112-byte stack frame, 16-byte aligned).
// ---------------------------------------------------------------------------
.text
.global _sgemm_direct_sme2_2VLx2VL
_sgemm_direct_sme2_2VLx2VL:
    // An empty C needs no work. Returning here also avoids the unpredicated
    // A load the first (unconditionally entered) M/N iteration would issue.
    cbz     x0, .Lreturn
    cbz     x2, .Lreturn

    stp     x19, x20, [sp, #-112]!
    stp     x21, x22, [sp, #16]
    stp     x23, x24, [sp, #32]
    // smstart/smstop zero Z0-Z31, which alias the callee-saved d8-d15, so
    // spill them in non-streaming mode before switching.
    stp     d8, d9, [sp, #48]
    stp     d10, d11, [sp, #64]
    stp     d12, d13, [sp, #80]
    stp     d14, d15, [sp, #96]
    smstart

    // constants (vector lengths are read after smstart: streaming VL)
    cntw    x6                      // SVLs
    mul     x22, x6, x1             // SVLs*K
    mul     x23, x6, x2             // SVLs*N
    add     x19, x23, x2            // SVLs*N + N
    add     x11, x4, x2, lsl #2     // Exit condition for N loop: B + N floats
    mov     x12, #0
    cntb    x6                      // SVLb
    mov     x14, #0
    ptrue   pn10.b                  // All-true predicate-as-counter (A loads)
    whilelt pn8.s, x12, x0, vlx2    // tiles predicate (M dimension)
    sub     w6, w6, #8              // SVLb-8: store-loop bound

.Loop_M:
    // Extract tile 0/1 and tile 2/3 predicates (M) from the vlx2 predicate.
    pext    { p2.s, p3.s }, pn8[0]
    mov     x16, x4                 // b_base
    mov     x9, x5                  // c_base
    whilelt pn9.b, x16, x11, vlx2   // tiles predicate (N dimension, bytes)

.Loop_N:
    mov     x7, x3                  // a_ptr = a_base
    mov     x17, x16                // b_ptr = b_base
    mov     x10, x9                 // c_ptr0 = c_base
    // Extract tile 0/2 and tile 1/3 predicates (N) from the vlx2 predicate.
    pext    { p0.b, p1.b }, pn9[0]
    add     x8, x3, x22, lsl #2     // a_base + SVLs*K FP32 elms (bytes)
    addvl   x15, x8, #-1            // Exit condition for K loop
    zero    {za}
    // The unrolled loop always consumes k = 0, 1 and 2, so it needs K >= 3.
    // Smaller K (including 0) uses only the one-step loop below.
    cmp     x1, #3
    b.lo    .Ktail_check

    ld1w    {z1.s}, p2/z, [x7]                  // Load 1st vector from a_ptr
    ld1w    {z2.s-z3.s}, pn9/z, [x17]           // Load 2 vectors from b_ptr
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s       // ZA0+=1st a_ptr vec OP 1st b_ptr vec
    ld1w    {z5.s}, p3/z, [x7, x22, lsl #2]     // Load 2nd vector from a_ptr
    addvl   x7, x7, #1                          // a_ptr += SVLb (bytes)

    // Software-pipelined: each pass finishes step k and issues k+1, k+2.
.Loop_K:
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s       // ZA2+=2nd a_ptr vec OP 1st b_ptr vec
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s       // ZA1+=1st a_ptr vec OP 2nd b_ptr vec
    ld1w    {z0.s-z1.s}, pn10/z, [x7]           // Load next 2 vectors from a_ptr
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s       // ZA3+=2nd a_ptr vec OP 2nd b_ptr vec
    ld1w    {z6.s-z7.s}, pn9/z, [x17, x2, lsl #2] // Load next 2 vecs from b_ptr
    fmopa   za0.s, p2/m, p0/m, z0.s, z6.s       // ZA0+=1st a_ptr vec OP 1st b_ptr vec
    psel    pn11, pn10, p3.s[w14, 0]            // all-true if 2nd row group active, else none
    ld1w    {z4.s-z5.s}, pn11/z, [x7, x22, lsl #2] // Load next 2 vecs from a_ptr
    fmopa   za2.s, p3/m, p0/m, z4.s, z6.s       // ZA2+=2nd a_ptr vec OP 1st b_ptr vec
    add     x17, x17, x2, lsl #3                // b_ptr += 2*N FP32 elms (bytes)
    fmopa   za1.s, p2/m, p1/m, z0.s, z7.s       // ZA1+=1st a_ptr vec OP 2nd b_ptr vec
    fmopa   za3.s, p3/m, p1/m, z4.s, z7.s       // ZA3+=2nd a_ptr vec OP 2nd b_ptr vec
    ld1w    {z2.s-z3.s}, pn9/z, [x17]           // Load next 2 vectors from b_ptr
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s       // ZA0+=1st a_ptr vec OP 1st b_ptr vec
    addvl   x7, x7, #2                          // a_ptr += 2*SVLb (bytes)
    cmp     x7, x15
    b.lo    .Loop_K                             // unsigned pointer compare

    // Drain: finish the step whose ZA0 product was issued last.
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s       // ZA2+=2nd a_ptr vec OP 1st b_ptr vec
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s       // ZA1+=1st a_ptr vec OP 2nd b_ptr vec
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s       // ZA3+=2nd a_ptr vec OP 2nd b_ptr vec
    add     x17, x17, x2, lsl #2                // b_ptr += N FP32 elms (bytes)

    // One k step per pass while a_ptr < a_end. After the unrolled loop this
    // runs at most once; for K < 3 it runs K times (0 for K == 0).
.Ktail_check:
    cmp     x7, x8
    b.hs    .Ktail_end
.Ktail_start:
    ld1w    {z1.s}, p2/z, [x7]
    ld1w    {z2.s-z3.s}, pn9/z, [x17]
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s
    ld1w    {z5.s}, p3/z, [x7, x22, lsl #2]
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s
    addvl   x7, x7, #1                          // a_ptr += SVLb (bytes)
    add     x17, x17, x2, lsl #2                // b_ptr += N FP32 elms (bytes)
    cmp     x7, x8
    b.lo    .Ktail_start
.Ktail_end:
    // Store ZA to C, two tile rows per pass. za0h.b slices w13+0..3 are row
    // w13/4 of ZA0..ZA3. p2/p3 .b element w13 is .s element w13/4, so psel
    // yields the N predicate (pn9) only for rows inside M.
    mov     w13, #0
    psel    pn11, pn9, p2.b[w13, 0]
    psel    pn12, pn9, p3.b[w13, 0]
    // z0..z3 = row w13/4 of ZA0, ZA1, ZA2, ZA3
    mova    { z0.b-z3.b }, za0h.b[w13, 0:3]
    st1w    { z0.s-z1.s }, pn11, [x10]              // Store to c_ptr0
    st1w    { z2.s-z3.s }, pn12, [x10, x23, lsl #2] // Store to c_ptr0+(SVLs*N)
.Loop_store_ZA:
    psel    pn11, pn9, p2.b[w13, 4]
    psel    pn12, pn9, p3.b[w13, 4]
    mova    { z0.b-z3.b }, za0h.b[w13, 4:7]         // next tile row
    st1w    { z0.s-z1.s }, pn11, [x10, x2, lsl #2]  // Store to c_ptr0+N
    st1w    { z2.s-z3.s }, pn12, [x10, x19, lsl #2] // Store to c_ptr0+(SVLs+1)*N
    add     x10, x10, x2, lsl #3                    // c_ptr0 += 2*N FP32 elms (bytes)
    add     w13, w13, #8
    psel    pn11, pn9, p2.b[w13, 0]
    psel    pn12, pn9, p3.b[w13, 0]
    mova    { z0.b-z3.b }, za0h.b[w13, 0:3]
    st1w    { z0.s-z1.s }, pn11, [x10]              // Store to c_ptr0
    st1w    { z2.s-z3.s }, pn12, [x10, x23, lsl #2] // Store to c_ptr0+SVLs*N
    cmp     w13, w6
    b.mi    .Loop_store_ZA
    psel    pn11, pn9, p2.b[w13, 4]
    psel    pn12, pn9, p3.b[w13, 4]
    mova    { z0.b-z3.b }, za0h.b[w13, 4:7]
    st1w    { z0.s-z1.s }, pn11, [x10, x2, lsl #2]  // Store to c_ptr0+N
    st1w    { z2.s-z3.s }, pn12, [x10, x19, lsl #2] // Store to c_ptr0+(SVLs+1)*N

    addvl   x9, x9, #2                      // c_base += 2*SVLb (bytes)
    addvl   x16, x16, #2                    // b_base += 2*SVLb (bytes)
    whilelt pn9.b, x16, x11, vlx2           // tile predicate (N dimension)
    b.first .Loop_N

    add     x3, x3, x22, lsl #3             // a_base += 2*SVLs*K FP32 elms (bytes)
    add     x5, x5, x23, lsl #3             // c_base += 2*SVLs*N FP32 elms (bytes)
    incw    x12, all, mul #2                // M loop counter += 2*SVLs
    whilelt pn8.s, x12, x0, vlx2            // tiles predicate (M dimension)
    b.first .Loop_M

    smstop
    // Reload d8-d15 only after smstop, which zeroes the vector registers again.
    ldp     d14, d15, [sp, #96]
    ldp     d12, d13, [sp, #80]
    ldp     d10, d11, [sp, #64]
    ldp     d8, d9, [sp, #48]
    ldp     x23, x24, [sp, #32]
    ldp     x21, x22, [sp, #16]
    ldp     x19, x20, [sp], #112
.Lreturn:
    ret