diff --git a/gimmik/hip.py b/gimmik/hip.py index a58c8fa..d11a08c 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -2,29 +2,143 @@ from gimmik.base import MatMul +import numpy as np + class HIPMatMul(MatMul): platform = 'hip' basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0} def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64): - # B loading, C streaming kernel - yield ('cstream', {}, {}) + max_block_threads = 1024 + max_shared = 64*1024 + + def emit(name, args, meta): + block = meta.get('block', self.basemeta['block']) + shared = meta.get('shared', self.basemeta['shared']) + threads = block[0]*block[1]*block[2] + + if threads <= max_block_threads and shared <= max_shared: + yield (name, args, meta) - # B streaming, C accumulation kernel - yield ('bstream', {}, {}) + def emit_preload(name, args, meta): + yield from emit(name, args | {'preload': True}, meta) - # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 64 args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield ('bstream-msplit', args, meta) + meta = { + 'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}' + } + yield from emit('bstream-msplit', args, meta) - # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 24, 64 args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) + meta = { + 'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}' + } + yield from emit('cstream-ksplit', args, meta) + + if dsize == 8: + blkx = 64 + a_hex, m_tiles, k_tiles, amask = self._mfma_dense_bake() + bix_rows = sorted(self.bix) # k-rows A actually uses + vec2_opts = [(False, '')] + if self.aligne is not None and self.aligne % 2 == 0: + vec2_opts.insert(0, (True, 'w2-')) + + for vec2, wpfx in vec2_opts: + for kc in [8, 16]: + shared = kc*4*blkx*dsize + for ms in [8, 16]: + args = { + 'blockx': blkx, 'a_hex': a_hex, + 'm_tiles': m_tiles, 'k_tiles': k_tiles, + 'amask': amask, 'msplit': ms, + 'bix_rows': bix_rows, 'vec2': vec2, 'kc': kc + } + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': ( + f'mfma-dense-msplit/{wpfx}' + f'm{m_tiles}-k{k_tiles}-s{ms}-kc{kc}-x{blkx}' + ) + } + yield from emit('mfma-dense-msplit', args, meta) + + # Tuned HIP variants + msplits, ksplits = [8, 4], [4, 2] + bsz, csz, blkx = 8, 8, 64 + widths = [1] + if self.aligne is not None and self.aligne % 2 == 0: + widths.insert(0, 2) + + for width in widths: + wargs = ({'dtype': f'{dtype}{width}', 'width': width} + if width > 1 else {}) + wmeta = {'width': width} if width > 1 else {} + wpfx = f'w{width}-' if width > 1 else '' + + for ms in msplits: + # m-split B streaming, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}' + } | wmeta + yield from emit('bstream-msplit', args, meta) + + for ms in msplits: + # m-split B streaming, C preloading, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': ( + f'bstream-msplit-preload-c/' + f'{wpfx}m{ms}-b{bsz}-x{blkx}' + ) + } | wmeta + yield from emit_preload('bstream-msplit', args, meta) + + for ks in ksplits: + # k-split B loading, C preloading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs + shared = (ks - 1)*csz*blkx*dsize*width + meta = { + 'block': (blkx, ks, 1), 'shared': shared, + 'desc': ( + f'cstream-ksplit-preload-c/' + f'{wpfx}k{ks}-c{csz}-x{blkx}' + ) + } | wmeta + yield from emit_preload('cstream-ksplit', args, meta) + + def _mfma_dense_bake(self): + # Densify, pad and reorder A into v_mfma_f64_16x16x4 fragment order: + # Ag[(mt*k_tiles + kt)*64 + lane] + # = A_pad[mt*16 + lane%16][kt*4 + lane//16] + # i.e. with lane = g*16 + p, operand A wants i = p, kk = g. + # amask[mt][kt] flags 16x4 A-tiles that contain a non-zero, so the + # kernel can skip the MMA (and, on the direct path, the B load) for + # all-zero tiles -- structural zero-tile skipping. + m, k = self.A.shape + m_tiles = -(-m // 16) + k_tiles = -(-k // 4) + a_pad = np.zeros((m_tiles*16, k_tiles*4), dtype=np.float64) + a_pad[:m, :k] = self.A + a_hex = [] + for mt in range(m_tiles): + for kt in range(k_tiles): + for lane in range(64): + i = mt*16 + (lane % 16) + kk = kt*4 + (lane // 16) + a_hex.append(float(a_pad[i, kk]).hex()) + amask = [[bool(np.any(a_pad[mt*16:mt*16+16, kt*4:kt*4+4])) + for kt in range(k_tiles)] for mt in range(m_tiles)] + return a_hex, m_tiles, k_tiles, amask def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 874fbbd..d67ee25 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -1,12 +1,74 @@ % if dtype.endswith('4'): -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } + +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); } + +inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0, 0, 0); } % elif dtype.endswith('2'): -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y); } + +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y); } + +inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0); } % else: -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} make_zero() { return 0; } % endif +static inline __device__ void +nt_store(${dtype}* p, ${dtype} v) +{ +% if dtype.endswith('4'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); + __builtin_nontemporal_store(v.z, &p->z); + __builtin_nontemporal_store(v.w, &p->w); +% elif dtype.endswith('2'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); +% else: + __builtin_nontemporal_store(v, p); +% endif +} + +static inline __device__ ${dtype} +nt_load(const ${dtype}* p) +{ +% if dtype.endswith('4'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y), + __builtin_nontemporal_load(&p->z), + __builtin_nontemporal_load(&p->w)); +% elif dtype.endswith('2'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y)); +% else: + return __builtin_nontemporal_load(p); +% endif +} + +static inline __device__ void +store_c(${dtype}* p, ${dtype} v) +{ + nt_store(p, v); +} + +static inline __device__ ${dtype} +load_c(const ${dtype}* p) +{ + return nt_load(p); +} + +static inline __device__ ${dtype} +load_b(const ${dtype}* p) +{ + return nt_load(p); +} + ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 6359ca1..52853f4 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -3,6 +3,7 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -12,7 +13,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -34,9 +35,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor + + % if preload and beta != 0: + ## Preload C values for active rows owned by this m-split lane + % for j, jx in enumerate(mx[cid]): + % if afix[jx] != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${jx}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]); + % endif + % endif + % endfor + % endif } % endfor __syncthreads(); @@ -51,7 +65,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor % endif @@ -59,18 +73,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for kx in bchunks[bb]: bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): - % if jx != 0 and kx == afix[mcx[j]]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[mcx[j]]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## If we're done with this dot product then store to global - % if kx == alix[mcx[j]] and beta == 0: - c[i + ${mcx[j]}*ldc] = csub[${j}]; + % if preload and kx == alix[mcx[j]]: + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + % elif kx == alix[mcx[j]] and beta == 0: + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: - c[i + ${mcx[j]}*ldc] += csub[${j}]; + store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); % elif kx == alix[mcx[j]]: - c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc]; + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${mcx[j]}*ldc])); % endif % endfor % endfor @@ -78,9 +96,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if loop.parent.last: % for j, jx in enumerate(afix): % if jx == -1 and j % msplit == cid and beta == 0: - c[i + ${j}*ldc] = make_zero(); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 2f6dc62..1e7a70b 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -1,13 +1,15 @@ <%inherit file='base'/> -__global__ __launch_bounds__(128) void +<% preload = context.get('preload', False) %> + +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -24,22 +26,39 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { ${dtype} bv, csub[${m}]; -## Iterare through the used rows of B +% if preload and beta != 0: +## Preload C values for rows which will receive a non-zero dot product +% for j, jx in enumerate(afix): + % if jx != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${j}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]); + % endif + % endif +% endfor +% endif + +## Iterate through the used rows of B % for kx in bix: - bv = b[i + ${kx}*ldb]; + bv = load_b(&b[i + ${kx}*ldb]); % for j, jx in enumerate(A[:, kx]): - % if jx != 0 and kx == afix[j]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[j]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## - % if kx == alix[j] and beta == 0: - c[i + ${j}*ldc] = csub[${j}]; + % if preload and kx == alix[j]: + store_c(&c[i + ${j}*ldc], csub[${j}]); + % elif kx == alix[j] and beta == 0: + store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: - c[i + ${j}*ldc] += csub[${j}]; + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]); % elif kx == alix[j]: - c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc]; + store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endfor @@ -47,9 +66,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Handle rows of A which are all zero % for j, jx in enumerate(afix): % if jx == -1 and beta == 0: - c[i + ${j}*ldc] = make_zero(); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index bae2d2a..12c59ba 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -4,6 +4,7 @@ kparts = partition(A, ksplit, by='cols') cchunks = chunk(range(m), csz) loaded = set() +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*ksplit}) void @@ -13,7 +14,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -43,14 +44,31 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> % endif % endfor - % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0': + <% + nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] + has_dotp = A[j].any() + if nzixs: + first_l_idx, first_kx = nzixs[0] + dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" + for l_idx, kx in nzixs[1:]: + dotex = f"{dotex} + {A[j, kx]}*bv[{l_idx}]" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; - % else: - dotp = make_zero(); - % endif ## Save to a register % if loop.index % ksplit == bid: + % if preload and beta == 0: cv[${loop.index // ksplit}] = dotp; + % elif preload and beta == 1 and has_dotp: + cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif preload and has_dotp: + cv[${loop.index // ksplit}] = ${beta}*load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif not preload: + cv[${loop.index // ksplit}] = dotp; + % endif ## Save to shared memory % else: csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; @@ -66,14 +84,32 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Sum and output the final set of dot products % for j in cchunk: % if loop.index % ksplit == bid: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; - % if beta == 0: - c[i + ${j}*ldc] = dotp; + <% has_dotp = A[j].any() %> + <% + sum_expr = f"cv[{loop.index // ksplit}]" + for s_idx in range(ksplit - 1): + sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" + %> + % if preload and beta == 0: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta == 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % elif beta == 0: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - c[i + ${j}*ldc] += dotp; + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index f75301d..2ee9574 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -1,15 +1,17 @@ <%inherit file='base'/> -<% ksplit = 2 if m < 36 else 1 %> +<% +preload = context.get('preload', False) +%> -__global__ __launch_bounds__(128) void +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -26,17 +28,39 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) if (i < n) { % for j, jx in enumerate(A): - % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0': + <% + nzixs = [kx for kx, val in enumerate(jx) if val != 0] + if nzixs: + first_kx = nzixs[0] + dotex = f"{jx[first_kx]}*b[i + {first_kx}*ldb]" + for kx in nzixs[1:]: + dotex = f"{dotex} + {jx[kx]}*b[i + {kx}*ldb]" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; + % if preload and nzixs: + % if beta == 0: + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1: + dotp = load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % else: + dotp = ${beta}*load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % endif + % elif preload: + % if beta == 0: + store_c(&c[i + ${j}*ldc], make_zero()); + % elif beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % endif + % elif beta == 0: + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1 and nzixs: + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - dotp = make_zero(); - % endif - % if beta == 0: - c[i + ${j}*ldc] = dotp; - % elif beta == 1 and dotex != '0.0': - c[i + ${j}*ldc] += dotp; - % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } diff --git a/gimmik/kernels/hip/mfma-dense-msplit.mako b/gimmik/kernels/hip/mfma-dense-msplit.mako new file mode 100644 index 0000000..6a8971e --- /dev/null +++ b/gimmik/kernels/hip/mfma-dense-msplit.mako @@ -0,0 +1,147 @@ +<%inherit file='base'/> +## +## Dense double-precision GEMM on the CDNA Matrix Cores (MFMA). +## +## A is densified, padded and baked into the kernel in Matrix-Core fragment +## order; B is staged through LDS; C is non-temporal stored; the epilogue is +## fully unrolled. This m-split + k-blocked path uses msplit wavefronts +## (block.y), each owning a slice of the m-tiles. B is staged into LDS in +## chunks of kc active k-tiles so LDS usage is bounded for any k, and only the +## k-rows A uses (bix) are read. +## +## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: +## A (16x4 ): A[i][kk] i=p, kk=g +## B (4x16 ): B[kk][j] kk=g, j=p +## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64) +## Bake: Ag[(mt*k_tiles+kt)*64 + lane] = A_pad[mt*16 + lane%16][kt*4 + lane//16] +## +<% + tiles = blockx // 16 + active_kt = [kt for kt in range(k_tiles) + if any(amask[mt][kt] for mt in range(m_tiles))] + mtpg = -(-m_tiles // msplit) +%> +typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); +typedef ${dtype} gimmik_f64x2 __attribute__((ext_vector_type(2))); + +__device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { + ${', '.join(a_hex)} +}; + +__global__ __launch_bounds__(${blockx * msplit}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${n}; + const ${'long long' if k * ldb >= 2**31 else 'int'} ldb = ${ldb}; + const ${'long long' if m * ldc >= 2**31 else 'int'} ldc = ${ldc}; +% endif + const int lane = threadIdx.x; + const int g = lane / 16; + const int p = lane % 16; + const int col_base = ${blockx}*blockIdx.x; + +<% + chunks = [active_kt[c:c+kc] for c in range(0, len(active_kt), kc)] + nthreads = blockx * msplit + half = blockx // 2 + mt_guard = (m_tiles % msplit != 0) +%> + __shared__ __align__(16) ${dtype} ${kname}_Bs[${kc * 4 * blockx}]; + const int tid = threadIdx.y*${blockx} + threadIdx.x; + const int wmt = threadIdx.y*${mtpg}; + ${dtype} a, ${', '.join('bv_%d' % t for t in range(tiles))}; +% for j in range(mtpg): +% for t in range(tiles): + gimmik_f64x4 acc_${j}_${t} = {0.0, 0.0, 0.0, 0.0}; +% endfor +% endfor + +% for ci, chunk in enumerate(chunks): +<% + cpos = {kt: a for a, kt in enumerate(chunk)} + bload = [(kr, cpos[kr // 4]*4 + kr % 4) + for kr in bix_rows if (kr // 4) in cpos] + nb = len(bload) + need_zero = nb < len(chunk)*4 +%> +% if ci > 0: + __syncthreads(); +% endif +% if need_zero: + for (int idx = tid; idx < ${len(chunk)*4*blockx}; idx += ${nthreads}) + ${kname}_Bs[idx] = (${dtype})0; + __syncthreads(); +% endif + { + static const int bg[${nb}] = { ${', '.join(str(x) for x, _ in bload)} }; + static const int bl[${nb}] = { ${', '.join(str(x) for _, x in bload)} }; +% if vec2: + for (int idx = tid; idx < ${nb * half}; idx += ${nthreads}) + { + const int r = idx / ${half}; + const int cc = (idx % ${half}) * 2; + const int col = col_base + cc; + if (col + 1 < n) + *(gimmik_f64x2*)&${kname}_Bs[bl[r]*${blockx} + cc] = + *(const gimmik_f64x2*)&b[bg[r]*ldb + col]; + else if (col < n) + ${kname}_Bs[bl[r]*${blockx} + cc] = b[bg[r]*ldb + col]; + } +% else: + for (int idx = tid; idx < ${nb * blockx}; idx += ${nthreads}) + { + const int r = idx / ${blockx}; + const int cc = idx % ${blockx}; + const int col = col_base + cc; + if (col < n) + ${kname}_Bs[bl[r]*${blockx} + cc] = b[bg[r]*ldb + col]; + } +% endif + } + __syncthreads(); +% for a_pos, kt in enumerate(chunk): +% for t in range(tiles): + bv_${t} = ${kname}_Bs[(${a_pos*4} + g)*${blockx} + ${t*16} + p]; +% endfor +% for j in range(mtpg): +% if mt_guard: + if (wmt + ${j} < ${m_tiles}) +% endif + { + a = ${kname}_Ag[((wmt + ${j})*${k_tiles} + ${kt})*64 + lane]; +% for t in range(tiles): + acc_${j}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${j}_${t}, 0, 0, 0); +% endfor + } +% endfor +% endfor +% endfor + +% for j in range(mtpg): +% for t in range(tiles): +% for reg in range(4): +% if mt_guard: + if (wmt + ${j} < ${m_tiles}) +% endif + { + const int row = (wmt + ${j})*16 + ${4*reg} + g; + const int col = col_base + ${t*16} + p; + if (row < ${m} && col < n) +% if beta == 0: + store_c(&c[row*ldc + col], acc_${j}_${t}[${reg}]); +% elif beta == 1: + store_c(&c[row*ldc + col], load_c(&c[row*ldc + col]) + acc_${j}_${t}[${reg}]); +% else: + store_c(&c[row*ldc + col], ${beta}*load_c(&c[row*ldc + col]) + acc_${j}_${t}[${reg}]); +% endif + } +% endfor +% endfor +% endfor +}