From 8e23d63cafcc24d6d5f11bbccc135e4f4ba482d7 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Thu, 18 Jun 2026 00:06:05 -0500 Subject: [PATCH 01/25] Add tuned HIP GiMMiK preload variants --- gimmik/hip.py | 113 ++++++++++++- gimmik/kernels/hip/base.mako | 32 ++++ .../kernels/hip/bstream-msplit-preload-c.mako | 98 +++++++++++ .../hip/bstream-msplit-width-preload-c.mako | 138 +++++++++++++++ gimmik/kernels/hip/bstream-msplit.mako | 10 +- gimmik/kernels/hip/bstream-preload-c.mako | 63 +++++++ .../kernels/hip/bstream-width-preload-c.mako | 103 ++++++++++++ gimmik/kernels/hip/bstream.mako | 12 +- .../kernels/hip/cstream-ksplit-preload-c.mako | 103 ++++++++++++ .../hip/cstream-ksplit-width-preload-c.mako | 157 ++++++++++++++++++ gimmik/kernels/hip/cstream-ksplit.mako | 6 +- gimmik/kernels/hip/cstream-preload-c.mako | 51 ++++++ .../kernels/hip/cstream-width-preload-c.mako | 106 ++++++++++++ gimmik/kernels/hip/cstream.mako | 8 +- 14 files changed, 978 insertions(+), 22 deletions(-) create mode 100644 gimmik/kernels/hip/bstream-msplit-preload-c.mako create mode 100644 gimmik/kernels/hip/bstream-msplit-width-preload-c.mako create mode 100644 gimmik/kernels/hip/bstream-preload-c.mako create mode 100644 gimmik/kernels/hip/bstream-width-preload-c.mako create mode 100644 gimmik/kernels/hip/cstream-ksplit-preload-c.mako create mode 100644 gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako create mode 100644 gimmik/kernels/hip/cstream-preload-c.mako create mode 100644 gimmik/kernels/hip/cstream-width-preload-c.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index a58c8fa..ee38c9f 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -8,23 +8,128 @@ class HIPMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0} def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64): + max_block_threads = 1024 + max_shared = 64*1024 + + def emit(name, args, meta): + block = meta.get('block', self.basemeta['block']) + shared = meta.get('shared', self.basemeta['shared']) + threads = block[0]*block[1]*block[2] + + if threads <= max_block_threads and shared <= max_shared: + yield (name, args, meta) + # B loading, C streaming kernel - yield ('cstream', {}, {}) + yield from emit('cstream', {}, {}) # B streaming, C accumulation kernel - yield ('bstream', {}, {}) + yield from emit('bstream', {}, {}) # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 64 args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield ('bstream-msplit', args, meta) + yield from emit('bstream-msplit', args, meta) # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 24, 64 args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) + yield from emit('cstream-ksplit', args, meta) + + # Tuned HIP variants + msplits, ksplits = [4, 8], [2, 4] + bsz, csz, blkx = 8, 8, 64 + width = 2 if self.aligne is not None and self.aligne % 2 == 0 else 1 + + # B loading, C streaming kernel + args = {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'desc': f'cstream/x{blkx}'} + yield from emit('cstream', args, meta) + + # B streaming, C accumulation kernel + meta = {'block': (blkx, 1, 1), 'desc': f'bstream/x{blkx}'} + yield from emit('bstream', args, meta) + + for ms in msplits: + # m-split B streaming, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + shared = 2*bsz*blkx*dsize + meta = {'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + yield from emit('bstream-msplit', args, meta) + + for ks in ksplits: + # k-split B loading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} + shared = (ks - 1)*csz*blkx*dsize + meta = {'block': (blkx, ks, 1), 'shared': shared, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + yield from emit('cstream-ksplit', args, meta) + + # B loading, C preloading, C streaming kernel + args = {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'desc': f'cstream-preload-c/x{blkx}'} + yield from emit('cstream-preload-c', args, meta) + + # B streaming, C preloading, C accumulation kernel + meta = {'block': (blkx, 1, 1), 'desc': f'bstream-preload-c/x{blkx}'} + yield from emit('bstream-preload-c', args, meta) + + if width > 1: + args = {'dtype': f'{dtype}{width}', 'width': width, + 'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': width, + 'desc': f'cstream-width-preload-c/w{width}-x{blkx}'} + yield from emit('cstream-width-preload-c', args, meta) + + meta = {'block': (blkx, 1, 1), 'width': width, + 'desc': f'bstream-width-preload-c/w{width}-x{blkx}'} + yield from emit('bstream-width-preload-c', args, meta) + + for ms in msplits: + # m-split B streaming, C preloading, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + shared = 2*bsz*blkx*dsize + meta = {'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit-preload-c/m{ms}-b{bsz}-x{blkx}'} + yield from emit('bstream-msplit-preload-c', args, meta) + + if width > 1: + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx, + 'dtype': f'{dtype}{width}', 'width': width} + meta = { + 'block': (blkx, ms, 1), 'shared': shared*width, + 'width': width, + 'desc': ( + f'bstream-msplit-width-preload-c/w{width}-' + f'm{ms}-b{bsz}-x{blkx}' + ) + } + yield from emit('bstream-msplit-width-preload-c', args, meta) + + for ks in ksplits: + # k-split B loading, C preloading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} + shared = (ks - 1)*csz*blkx*dsize + meta = { + 'block': (blkx, ks, 1), 'shared': shared, + 'desc': f'cstream-ksplit-preload-c/k{ks}-c{csz}-x{blkx}' + } + yield from emit('cstream-ksplit-preload-c', args, meta) + + if width > 1: + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx, + 'dtype': f'{dtype}{width}', 'width': width} + meta = { + 'block': (blkx, ks, 1), 'shared': shared*width, + 'width': width, + 'desc': ( + f'cstream-ksplit-width-preload-c/w{width}-' + f'k{ks}-c{csz}-x{blkx}' + ) + } + yield from emit('cstream-ksplit-width-preload-c', args, meta) def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 874fbbd..b40a0b9 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -9,4 +9,36 @@ static inline __device__ ${dtype} make_zero() { return 0; } % endif +static inline __device__ void +nt_store_c(${dtype}* p, ${dtype} v) +{ +% if dtype.endswith('4'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); + __builtin_nontemporal_store(v.z, &p->z); + __builtin_nontemporal_store(v.w, &p->w); +% elif dtype.endswith('2'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); +% else: + __builtin_nontemporal_store(v, p); +% endif +} + +static inline __device__ ${dtype} +nt_load_c(const ${dtype}* p) +{ +% if dtype.endswith('4'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y), + __builtin_nontemporal_load(&p->z), + __builtin_nontemporal_load(&p->w)); +% elif dtype.endswith('2'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y)); +% else: + return __builtin_nontemporal_load(p); +% endif +} + ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-preload-c.mako new file mode 100644 index 0000000..8b6f008 --- /dev/null +++ b/gimmik/kernels/hip/bstream-msplit-preload-c.mako @@ -0,0 +1,98 @@ +<%inherit file='base'/> + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +%> + +__global__ __launch_bounds__(${blockx*msplit}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = ((n + ${width} - 1) / ${width}) * ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + int i = blockDim.x*blockIdx.x + threadIdx.x; + + ${dtype} bv, csub[${-(-m // msplit)}]; + __shared__ ${dtype} bsub[2][${bsz}][${blockx}]; + +## Fill the initial shared memory block +% for cid in range(msplit): + if (i < n && threadIdx.y == ${cid}) + { + % for kx in bchunks[0]: + % if loop.index % msplit == cid: + bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + % endif + % endfor + + ## Preload C values for active rows owned by this m-split lane + % for j, jx in enumerate(mx[cid]): + % if afix[jx] != -1: + % if beta == 0: + csub[${j}] = make_zero(); + % elif beta == 1: + csub[${j}] = nt_load_c(&c[i + ${jx}*ldc]); + % else: + csub[${j}] = ${beta}*nt_load_c(&c[i + ${jx}*ldc]); + % endif + % endif + % endfor + } +% endfor + __syncthreads(); + +## Iterate over each row-chunk of B +% for bb in range(len(bchunks)): + ## Iterate over each row-chunk of C + % for cid, mcx in enumerate(mx): + if (i < n && threadIdx.y == ${cid}) + { + ## Start filling the next shared memory block + % if not loop.parent.last: + % for kx in bchunks[bb + 1]: + % if loop.index % msplit == cid: + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + % endif + % endfor + % endif + ## Accumulate our dot products + % for kx in bchunks[bb]: + bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; + % for j, jx in enumerate(A[mcx, kx]): + % if jx != 0: + csub[${j}] += ${jx}*bv; + % endif + ## If we're done with this dot product then store to global + % if kx == alix[mcx[j]]: + nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + % endif + % endfor + % endfor + ## Handle rows of A which are all zero + % if loop.parent.last: + % for j, jx in enumerate(afix): + % if jx == -1 and j % msplit == cid and beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif jx == -1 and j % msplit == cid and beta != 1: + nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + % endif + % endfor + % endif + } + % endfor + __syncthreads(); +% endfor +} diff --git a/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako new file mode 100644 index 0000000..9659db7 --- /dev/null +++ b/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako @@ -0,0 +1,138 @@ +<%inherit file='base'/> + +% if width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "bstream_msplit_width_preload_c only supports width=2 or width=4" +% endif + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +%> + +__global__ __launch_bounds__(${blockx*msplit}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = (n + ${width} - 1) / ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + int i = blockDim.x*blockIdx.x + threadIdx.x; + + ${dtype} bv, csub[${-(-m // msplit)}]; + __shared__ ${dtype} bsub[2][${bsz}][${blockx}]; + +## Fill the initial shared memory block +% for cid in range(msplit): + if (i < n && threadIdx.y == ${cid}) + { + % for kx in bchunks[0]: + % if loop.index % msplit == cid: + bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + % endif + % endfor + + ## Preload C values for active rows owned by this m-split lane + % for j, jx in enumerate(mx[cid]): + % if afix[jx] != -1: + % if beta == 0: + csub[${j}] = make_zero(); + % elif beta == 1: + csub[${j}] = nt_load_c(&c[i + ${jx}*ldc]); + % else: + csub[${j}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${jx}*ldc])); + % endif + % endif + % endfor + } +% endfor + __syncthreads(); + +## Iterate over each row-chunk of B +% for bb in range(len(bchunks)): + ## Iterate over each row-chunk of C + % for cid, mcx in enumerate(mx): + if (i < n && threadIdx.y == ${cid}) + { + ## Start filling the next shared memory block + % if not loop.parent.last: + % for kx in bchunks[bb + 1]: + % if loop.index % msplit == cid: + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + % endif + % endfor + % endif + ## Accumulate our dot products + % for kx in bchunks[bb]: + bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; + % for j, jx in enumerate(A[mcx, kx]): + % if jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + % endif + ## If we're done with this dot product then store to global + % if kx == alix[mcx[j]]: + nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + % endif + % endfor + % endfor + ## Handle rows of A which are all zero + % if loop.parent.last: + % for j, jx in enumerate(afix): + % if jx == -1 and j % msplit == cid and beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif jx == -1 and j % msplit == cid and beta != 1: + nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); + % endif + % endfor + % endif + } + % endfor + __syncthreads(); +% endfor +} diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 6359ca1..6470477 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -66,11 +66,11 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endif ## If we're done with this dot product then store to global % if kx == alix[mcx[j]] and beta == 0: - c[i + ${mcx[j]}*ldc] = csub[${j}]; + nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: - c[i + ${mcx[j]}*ldc] += csub[${j}]; + nt_store_c(&c[i + ${mcx[j]}*ldc], nt_load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); % elif kx == alix[mcx[j]]: - c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc]; + nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*nt_load_c(&c[i + ${mcx[j]}*ldc])); % endif % endfor % endfor @@ -78,9 +78,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if loop.parent.last: % for j, jx in enumerate(afix): % if jx == -1 and j % msplit == cid and beta == 0: - c[i + ${j}*ldc] = make_zero(); + nt_store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc])*${beta}); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-preload-c.mako b/gimmik/kernels/hip/bstream-preload-c.mako new file mode 100644 index 0000000..30b08f6 --- /dev/null +++ b/gimmik/kernels/hip/bstream-preload-c.mako @@ -0,0 +1,63 @@ +<%inherit file='base'/> + +__global__ __launch_bounds__(${blockx}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = ((n + ${width} - 1) / ${width}) * ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i < n) + { + ${dtype} bv, csub[${m}]; + +## Preload C values for rows which will receive a non-zero dot product +% for j, jx in enumerate(afix): + % if jx != -1: + % if beta == 0: + csub[${j}] = make_zero(); + % elif beta == 1: + csub[${j}] = nt_load_c(&c[i + ${j}*ldc]); + % else: + csub[${j}] = ${beta}*nt_load_c(&c[i + ${j}*ldc]); + % endif + % endif +% endfor + +## Iterate through the used rows of B +% for kx in bix: + bv = b[i + ${kx}*ldb]; + % for j, jx in enumerate(A[:, kx]): + % if jx != 0: + csub[${j}] += ${jx}*bv; + % endif + ## + % if kx == alix[j]: + nt_store_c(&c[i + ${j}*ldc], csub[${j}]); + % endif + % endfor +% endfor + +## Handle rows of A which are all zero +% for j, jx in enumerate(afix): + % if jx == -1 and beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif jx == -1 and beta != 1: + nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + % endif +% endfor + } +} diff --git a/gimmik/kernels/hip/bstream-width-preload-c.mako b/gimmik/kernels/hip/bstream-width-preload-c.mako new file mode 100644 index 0000000..97a7571 --- /dev/null +++ b/gimmik/kernels/hip/bstream-width-preload-c.mako @@ -0,0 +1,103 @@ +<%inherit file='base'/> + +% if width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "bstream_width_preload_c only supports width=2 or width=4" +% endif + +__global__ __launch_bounds__(${blockx}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = (n + ${width} - 1) / ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i < n) + { + ${dtype} bv, csub[${m}]; + +## Preload C values for rows which will receive a non-zero dot product +% for j, jx in enumerate(afix): + % if jx != -1: + % if beta == 0: + csub[${j}] = make_zero(); + % elif beta == 1: + csub[${j}] = nt_load_c(&c[i + ${j}*ldc]); + % else: + csub[${j}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); + % endif + % endif +% endfor + +## Iterate through the used rows of B +% for kx in bix: + bv = b[i + ${kx}*ldb]; + % for j, jx in enumerate(A[:, kx]): + % if jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + % endif + ## + % if kx == alix[j]: + nt_store_c(&c[i + ${j}*ldc], csub[${j}]); + % endif + % endfor +% endfor + +## Handle rows of A which are all zero +% for j, jx in enumerate(afix): + % if jx == -1 and beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif jx == -1 and beta != 1: + nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); + % endif +% endfor + } +} diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 2f6dc62..9634c73 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -__global__ __launch_bounds__(128) void +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, @@ -35,11 +35,11 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endif ## % if kx == alix[j] and beta == 0: - c[i + ${j}*ldc] = csub[${j}]; + nt_store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: - c[i + ${j}*ldc] += csub[${j}]; + nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + csub[${j}]); % elif kx == alix[j]: - c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc]; + nt_store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*nt_load_c(&c[i + ${j}*ldc])); % endif % endfor % endfor @@ -47,9 +47,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Handle rows of A which are all zero % for j, jx in enumerate(afix): % if jx == -1 and beta == 0: - c[i + ${j}*ldc] = make_zero(); + nt_store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc])*${beta}); % endif % endfor } diff --git a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako new file mode 100644 index 0000000..507c34f --- /dev/null +++ b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako @@ -0,0 +1,103 @@ +<%inherit file='base'/> + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(range(m), csz) +loaded = set() +%> + +__global__ __launch_bounds__(${blockx*ksplit}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = ((n + ${width} - 1) / ${width}) * ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + int i = blockDim.x*blockIdx.x + threadIdx.x; + + ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp; + __shared__ ${dtype} csub[${ksplit - 1}][${csz}][${blockx}]; + +## Iterate over the row-partitions of C +% for cchunk in cchunks: + ## Iterate over the row-partitions of B + % for bid, kbx in enumerate(kparts): + if (i < n && threadIdx.y == ${bid}) + { + ## Evaluate our partial dot products + % for j in cchunk: + ## Load in any missing parts of B + % for kx in kbx: + % if A[j, kx] != 0 and kx not in loaded: + bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> + % endif + % endfor + <% + dotex = dot(lambda kx: f'bv[{kx}]', A[j, kbx]) + has_dotp = any(A[j, kx] != 0 for kx in range(k)) + %> + % if dotex != '0.0': + dotp = ${dotex}; + % else: + dotp = make_zero(); + % endif + ## Save to a register + % if loop.index % ksplit == bid: + % if beta == 0: + cv[${loop.index // ksplit}] = dotp; + % elif beta == 1 and has_dotp: + cv[${loop.index // ksplit}] = nt_load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif has_dotp: + cv[${loop.index // ksplit}] = ${beta}*nt_load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % endif + ## Save to shared memory + % else: + csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; + % endif + % endfor + } + % endfor + __syncthreads(); + ## Iterate over the column-partitions of B + % for bid, kbx in enumerate(kparts): + if (i < n && threadIdx.y == ${bid}) + { + ## Sum and output the final set of dot products + % for j in cchunk: + % if loop.index % ksplit == bid: + <% has_dotp = any(A[j, kx] != 0 for kx in range(k)) %> + % if beta == 0: + dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' + for i in range(ksplit - 1))}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1 and has_dotp: + dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' + for i in range(ksplit - 1))}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta != 1 and has_dotp: + dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' + for i in range(ksplit - 1))}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta != 1: + nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + % endif + % endif + % endfor + } + % endfor + __syncthreads(); +% endfor +} diff --git a/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako new file mode 100644 index 0000000..b435913 --- /dev/null +++ b/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako @@ -0,0 +1,157 @@ +<%inherit file='base'/> + +% if width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "cstream_ksplit_width_preload_c only supports width=2 or width=4" +% endif + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(range(m), csz) +loaded = set() +%> + +__global__ __launch_bounds__(${blockx*ksplit}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = (n + ${width} - 1) / ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + int i = blockDim.x*blockIdx.x + threadIdx.x; + + ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp; + __shared__ ${dtype} csub[${ksplit - 1}][${csz}][${blockx}]; + +## Iterate over the row-partitions of C +% for cchunk in cchunks: + ## Iterate over the column-partitions of B + % for bid, kbx in enumerate(kparts): + if (i < n && threadIdx.y == ${bid}) + { + ## Evaluate our partial dot products + % for j in cchunk: + ## Load in any missing parts of B + % for kx in kbx: + % if A[j, kx] != 0 and kx not in loaded: + bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> + % endif + % endfor + + ## Expand vectorized partial dot product + <% + nzixs = [] + for l_idx, kx in enumerate(kbx): + if A[j, kx] != 0: + nzixs.append((l_idx, kx)) + + has_dotp = any(A[j, kx] != 0 for kx in range(k)) + if not nzixs: + dotex = 'make_zero()' + else: + first_l_idx, first_kx = nzixs[0] + dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" + for l_idx, kx in nzixs[1:]: + dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" + %> + dotp = ${dotex}; + + ## Save to a register + % if loop.index % ksplit == bid: + % if beta == 0: + cv[${loop.index // ksplit}] = dotp; + % elif beta == 1 and has_dotp: + cv[${loop.index // ksplit}] = nt_load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); + % elif has_dotp: + cv[${loop.index // ksplit}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); + cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); + % endif + ## Save to shared memory + % else: + csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; + % endif + % endfor + } + % endfor + __syncthreads(); + + ## Sum and output the final set of dot products + % for bid, kbx in enumerate(kparts): + if (i < n && threadIdx.y == ${bid}) + { + % for j in cchunk: + % if loop.index % ksplit == bid: + <% + has_dotp = any(A[j, kx] != 0 for kx in range(k)) + sum_expr = f"cv[{loop.index // ksplit}]" + for s_idx in range(ksplit - 1): + sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" + %> + % if beta == 0: + dotp = ${sum_expr}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1 and has_dotp: + dotp = ${sum_expr}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta != 1 and has_dotp: + dotp = ${sum_expr}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta != 1: + nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); + % endif + % endif + % endfor + } + % endfor + __syncthreads(); +% endfor +} diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index bae2d2a..6fd3210 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -69,11 +69,11 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' for i in range(ksplit - 1))}; % if beta == 0: - c[i + ${j}*ldc] = dotp; + nt_store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - c[i + ${j}*ldc] += dotp; + nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + dotp); % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + nt_store_c(&c[i + ${j}*ldc], dotp + ${beta}*nt_load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-preload-c.mako b/gimmik/kernels/hip/cstream-preload-c.mako new file mode 100644 index 0000000..041e674 --- /dev/null +++ b/gimmik/kernels/hip/cstream-preload-c.mako @@ -0,0 +1,51 @@ +<%inherit file='base'/> + +<% ksplit = 2 if m < 36 else 1 %> + +__global__ __launch_bounds__(128) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = ((n + ${width} - 1) / ${width}) * ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + const int i = blockDim.x*blockIdx.x + threadIdx.x; + ${dtype} dotp; + + if (i < n) + { +% for j, jx in enumerate(A): + % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0': + % if beta == 0: + dotp = ${dotex}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1: + dotp = nt_load_c(&c[i + ${j}*ldc]); + dotp += ${dotex}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % else: + dotp = ${beta}*nt_load_c(&c[i + ${j}*ldc]); + dotp += ${dotex}; + nt_store_c(&c[i + ${j}*ldc], dotp); + % endif + % else: + % if beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif beta != 1: + nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + % endif + % endif +% endfor + } +} diff --git a/gimmik/kernels/hip/cstream-width-preload-c.mako b/gimmik/kernels/hip/cstream-width-preload-c.mako new file mode 100644 index 0000000..9f2f57c --- /dev/null +++ b/gimmik/kernels/hip/cstream-width-preload-c.mako @@ -0,0 +1,106 @@ +<%inherit file='base'/> + +% if width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "cstream_width_preload_c only supports width=2 or width=4" +% endif + +__global__ __launch_bounds__(${blockx}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ + % if width > 1: + n = (n + ${width} - 1) / ${width}; + ldb /= ${width}; + ldc /= ${width}; + % endif +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${-(-n // width)}; + const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; + const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; +% endif + const int i = blockDim.x*blockIdx.x + threadIdx.x; + ${dtype} bv, dotp; + + if (i < n) + { +% for j, row in enumerate(A): + <% + nzixs = [kx for kx, val in enumerate(row) if val != 0] + %> + % if nzixs: + % if beta == 0: + <% first_kx = nzixs[0] %> + bv = b[i + ${first_kx}*ldb]; + dotp = gimmik_vmul(${row[first_kx]}, bv); + % for kx in nzixs[1:]: + bv = b[i + ${kx}*ldb]; + dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); + % endfor + nt_store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1: + dotp = nt_load_c(&c[i + ${j}*ldc]); + % for kx in nzixs: + bv = b[i + ${kx}*ldb]; + dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); + % endfor + nt_store_c(&c[i + ${j}*ldc], dotp); + % else: + dotp = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); + % for kx in nzixs: + bv = b[i + ${kx}*ldb]; + dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); + % endfor + nt_store_c(&c[i + ${j}*ldc], dotp); + % endif + % else: + % if beta == 0: + nt_store_c(&c[i + ${j}*ldc], make_zero()); + % elif beta != 1: + nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); + % endif + % endif +% endfor + } +} diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index f75301d..0651e87 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -2,7 +2,7 @@ <% ksplit = 2 if m < 36 else 1 %> -__global__ __launch_bounds__(128) void +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, @@ -32,11 +32,11 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) dotp = make_zero(); % endif % if beta == 0: - c[i + ${j}*ldc] = dotp; + nt_store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1 and dotex != '0.0': - c[i + ${j}*ldc] += dotp; + nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + dotp); % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + nt_store_c(&c[i + ${j}*ldc], dotp + ${beta}*nt_load_c(&c[i + ${j}*ldc])); % endif % endfor } From 8f4d03ee5b2c6e00a1bb7872bca64f0fae03d9d4 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Thu, 18 Jun 2026 00:12:02 -0500 Subject: [PATCH 02/25] Fix HIP GiMMiK block size metadata --- gimmik/hip.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index ee38c9f..920cc0b 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -19,11 +19,13 @@ def emit(name, args, meta): if threads <= max_block_threads and shared <= max_shared: yield (name, args, meta) + blkx = self.basemeta['block'][0] + # B loading, C streaming kernel - yield from emit('cstream', {}, {}) + yield from emit('cstream', {'blockx': blkx}, {}) # B streaming, C accumulation kernel - yield from emit('bstream', {}, {}) + yield from emit('bstream', {'blockx': blkx}, {}) # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 64 From 96671a6364f8ccbac0d14d89364cda13f5223952 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Sun, 21 Jun 2026 23:18:07 -0500 Subject: [PATCH 03/25] Address HIP GiMMiK review comments --- gimmik/hip.py | 4 ++ .../hip/bstream-msplit-width-preload-c.mako | 40 +-------------- .../kernels/hip/bstream-width-preload-c.mako | 40 +-------------- .../kernels/hip/cstream-ksplit-preload-c.mako | 4 +- .../hip/cstream-ksplit-width-preload-c.mako | 51 ++----------------- .../kernels/hip/cstream-width-preload-c.mako | 42 +-------------- gimmik/kernels/hip/vector.mako | 41 +++++++++++++++ 7 files changed, 54 insertions(+), 168 deletions(-) create mode 100644 gimmik/kernels/hip/vector.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index 920cc0b..cea1be5 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -39,6 +39,10 @@ def emit(name, args, meta): meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} yield from emit('cstream-ksplit', args, meta) + # Only emit tuned variants on the architecture they were tuned for. + if gcn_arch != 'gfx942' or warp_size != 64: + return + # Tuned HIP variants msplits, ksplits = [4, 8], [2, 4] bsz, csz, blkx = 8, 8, 64 diff --git a/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako index 9659db7..4466f9f 100644 --- a/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako +++ b/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako @@ -1,44 +1,6 @@ <%inherit file='base'/> -% if width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} -% else: -#error "bstream_msplit_width_preload_c only supports width=2 or width=4" -% endif +<%include file='vector'/> <% mx = partition(A, into=msplit, by='rows') diff --git a/gimmik/kernels/hip/bstream-width-preload-c.mako b/gimmik/kernels/hip/bstream-width-preload-c.mako index 97a7571..2c4e5c5 100644 --- a/gimmik/kernels/hip/bstream-width-preload-c.mako +++ b/gimmik/kernels/hip/bstream-width-preload-c.mako @@ -1,44 +1,6 @@ <%inherit file='base'/> -% if width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} -% else: -#error "bstream_width_preload_c only supports width=2 or width=4" -% endif +<%include file='vector'/> __global__ __launch_bounds__(${blockx}) void % if n is None: diff --git a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako index 507c34f..51f1db4 100644 --- a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako +++ b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako @@ -45,7 +45,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endfor <% dotex = dot(lambda kx: f'bv[{kx}]', A[j, kbx]) - has_dotp = any(A[j, kx] != 0 for kx in range(k)) + has_dotp = A[j].any() %> % if dotex != '0.0': dotp = ${dotex}; @@ -78,7 +78,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Sum and output the final set of dot products % for j in cchunk: % if loop.index % ksplit == bid: - <% has_dotp = any(A[j, kx] != 0 for kx in range(k)) %> + <% has_dotp = A[j].any() %> % if beta == 0: dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' for i in range(ksplit - 1))}; diff --git a/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako index b435913..bdac6dc 100644 --- a/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako +++ b/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako @@ -1,46 +1,6 @@ <%inherit file='base'/> -% if width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} -% else: -#error "cstream_ksplit_width_preload_c only supports width=2 or width=4" -% endif +<%include file='vector'/> <% kparts = partition(A, ksplit, by='cols') @@ -88,12 +48,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Expand vectorized partial dot product <% - nzixs = [] - for l_idx, kx in enumerate(kbx): - if A[j, kx] != 0: - nzixs.append((l_idx, kx)) + nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] - has_dotp = any(A[j, kx] != 0 for kx in range(k)) + has_dotp = A[j].any() if not nzixs: dotex = 'make_zero()' else: @@ -131,7 +88,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for j in cchunk: % if loop.index % ksplit == bid: <% - has_dotp = any(A[j, kx] != 0 for kx in range(k)) + has_dotp = A[j].any() sum_expr = f"cv[{loop.index // ksplit}]" for s_idx in range(ksplit - 1): sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" diff --git a/gimmik/kernels/hip/cstream-width-preload-c.mako b/gimmik/kernels/hip/cstream-width-preload-c.mako index 9f2f57c..86acfcb 100644 --- a/gimmik/kernels/hip/cstream-width-preload-c.mako +++ b/gimmik/kernels/hip/cstream-width-preload-c.mako @@ -1,46 +1,6 @@ <%inherit file='base'/> -% if width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} -% else: -#error "cstream_width_preload_c only supports width=2 or width=4" -% endif +<%include file='vector'/> __global__ __launch_bounds__(${blockx}) void % if n is None: diff --git a/gimmik/kernels/hip/vector.mako b/gimmik/kernels/hip/vector.mako new file mode 100644 index 0000000..268d6ab --- /dev/null +++ b/gimmik/kernels/hip/vector.mako @@ -0,0 +1,41 @@ +% if width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "HIP vector helpers only support width=2 or width=4" +% endif From 739a82e18f508597081da9c73fb6ed780e7817ed Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Mon, 22 Jun 2026 01:45:23 -0500 Subject: [PATCH 04/25] Handle ROCm feature suffixes for gfx942 tuning --- gimmik/hip.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index cea1be5..a31c3d6 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -40,7 +40,8 @@ def emit(name, args, meta): yield from emit('cstream-ksplit', args, meta) # Only emit tuned variants on the architecture they were tuned for. - if gcn_arch != 'gfx942' or warp_size != 64: + base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None + if base_arch != 'gfx942' or warp_size != 64: return # Tuned HIP variants From 0633539fbe6c99c97fdbf7aa91c0cb8df35e5160 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Mon, 22 Jun 2026 03:30:28 -0500 Subject: [PATCH 05/25] Enable tuned HIP variants on gfx90a --- gimmik/hip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index a31c3d6..e32152d 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -39,9 +39,9 @@ def emit(name, args, meta): meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} yield from emit('cstream-ksplit', args, meta) - # Only emit tuned variants on the architecture they were tuned for. + # Only emit tuned variants on architectures they have been validated for. base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None - if base_arch != 'gfx942' or warp_size != 64: + if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: return # Tuned HIP variants From 7b59fb0a791cbf1324a46c5b0451bfc63a8b6140 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Tue, 23 Jun 2026 04:57:56 -0500 Subject: [PATCH 06/25] Parameterize HIP vector width and refine preload kernels --- gimmik/hip.py | 155 ++++++++---------- gimmik/kernels/hip/base.mako | 83 ++++++++++ .../kernels/hip/bstream-msplit-preload-c.mako | 28 ++-- .../hip/bstream-msplit-width-preload-c.mako | 100 ----------- gimmik/kernels/hip/bstream-msplit.mako | 16 +- gimmik/kernels/hip/bstream-preload-c.mako | 28 ++-- .../kernels/hip/bstream-width-preload-c.mako | 65 -------- gimmik/kernels/hip/bstream.mako | 16 +- .../kernels/hip/cstream-ksplit-preload-c.mako | 45 ++--- .../hip/cstream-ksplit-width-preload-c.mako | 114 ------------- gimmik/kernels/hip/cstream-ksplit.mako | 30 ++-- gimmik/kernels/hip/cstream-preload-c.mako | 32 ++-- .../kernels/hip/cstream-width-preload-c.mako | 66 -------- gimmik/kernels/hip/cstream.mako | 23 ++- gimmik/kernels/hip/vector.mako | 41 ----- 15 files changed, 286 insertions(+), 556 deletions(-) delete mode 100644 gimmik/kernels/hip/bstream-msplit-width-preload-c.mako delete mode 100644 gimmik/kernels/hip/bstream-width-preload-c.mako delete mode 100644 gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako delete mode 100644 gimmik/kernels/hip/cstream-width-preload-c.mako delete mode 100644 gimmik/kernels/hip/vector.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index e32152d..57fa394 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -47,96 +47,83 @@ def emit(name, args, meta): # Tuned HIP variants msplits, ksplits = [4, 8], [2, 4] bsz, csz, blkx = 8, 8, 64 - width = 2 if self.aligne is not None and self.aligne % 2 == 0 else 1 - - # B loading, C streaming kernel - args = {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'desc': f'cstream/x{blkx}'} - yield from emit('cstream', args, meta) - - # B streaming, C accumulation kernel - meta = {'block': (blkx, 1, 1), 'desc': f'bstream/x{blkx}'} - yield from emit('bstream', args, meta) - - for ms in msplits: - # m-split B streaming, C accumulation kernel - args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - shared = 2*bsz*blkx*dsize - meta = {'block': (blkx, ms, 1), 'shared': shared, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} - yield from emit('bstream-msplit', args, meta) - - for ks in ksplits: - # k-split B loading, C streaming kernel - args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} - shared = (ks - 1)*csz*blkx*dsize - meta = {'block': (blkx, ks, 1), 'shared': shared, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} - yield from emit('cstream-ksplit', args, meta) - - # B loading, C preloading, C streaming kernel - args = {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'desc': f'cstream-preload-c/x{blkx}'} - yield from emit('cstream-preload-c', args, meta) - - # B streaming, C preloading, C accumulation kernel - meta = {'block': (blkx, 1, 1), 'desc': f'bstream-preload-c/x{blkx}'} - yield from emit('bstream-preload-c', args, meta) - - if width > 1: - args = {'dtype': f'{dtype}{width}', 'width': width, - 'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'width': width, - 'desc': f'cstream-width-preload-c/w{width}-x{blkx}'} - yield from emit('cstream-width-preload-c', args, meta) - - meta = {'block': (blkx, 1, 1), 'width': width, - 'desc': f'bstream-width-preload-c/w{width}-x{blkx}'} - yield from emit('bstream-width-preload-c', args, meta) - - for ms in msplits: - # m-split B streaming, C preloading, C accumulation kernel - args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - shared = 2*bsz*blkx*dsize - meta = {'block': (blkx, ms, 1), 'shared': shared, - 'desc': f'bstream-msplit-preload-c/m{ms}-b{bsz}-x{blkx}'} - yield from emit('bstream-msplit-preload-c', args, meta) - - if width > 1: - args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx, - 'dtype': f'{dtype}{width}', 'width': width} + widths = [1] + if self.aligne is not None and self.aligne % 2 == 0: + widths.append(2) + + for width in widths: + wargs = ({'dtype': f'{dtype}{width}', 'width': width} + if width > 1 else {}) + wmeta = {'width': width} if width > 1 else {} + wpfx = f'w{width}-' if width > 1 else '' + + # B loading, C streaming kernel + args = {'blockx': blkx} | wargs + meta = {'block': (blkx, 1, 1), + 'desc': f'cstream/{wpfx}x{blkx}'} | wmeta + yield from emit('cstream', args, meta) + + # B streaming, C accumulation kernel + meta = {'block': (blkx, 1, 1), + 'desc': f'bstream/{wpfx}x{blkx}'} | wmeta + yield from emit('bstream', args, meta) + + for ms in msplits: + # m-split B streaming, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width meta = { - 'block': (blkx, ms, 1), 'shared': shared*width, - 'width': width, + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}' + } | wmeta + yield from emit('bstream-msplit', args, meta) + + for ks in ksplits: + # k-split B loading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs + shared = (ks - 1)*csz*blkx*dsize*width + meta = { + 'block': (blkx, ks, 1), 'shared': shared, + 'desc': f'cstream-ksplit/{wpfx}k{ks}-c{csz}-x{blkx}' + } | wmeta + yield from emit('cstream-ksplit', args, meta) + + # B loading, C preloading, C streaming kernel + args = {'blockx': blkx} | wargs + meta = {'block': (blkx, 1, 1), + 'desc': f'cstream-preload-c/{wpfx}x{blkx}'} | wmeta + yield from emit('cstream-preload-c', args, meta) + + # B streaming, C preloading, C accumulation kernel + meta = {'block': (blkx, 1, 1), + 'desc': f'bstream-preload-c/{wpfx}x{blkx}'} | wmeta + yield from emit('bstream-preload-c', args, meta) + + for ms in msplits: + # m-split B streaming, C preloading, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width + meta = { + 'block': (blkx, ms, 1), 'shared': shared, 'desc': ( - f'bstream-msplit-width-preload-c/w{width}-' - f'm{ms}-b{bsz}-x{blkx}' + f'bstream-msplit-preload-c/' + f'{wpfx}m{ms}-b{bsz}-x{blkx}' ) - } - yield from emit('bstream-msplit-width-preload-c', args, meta) - - for ks in ksplits: - # k-split B loading, C preloading, C streaming kernel - args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} - shared = (ks - 1)*csz*blkx*dsize - meta = { - 'block': (blkx, ks, 1), 'shared': shared, - 'desc': f'cstream-ksplit-preload-c/k{ks}-c{csz}-x{blkx}' - } - yield from emit('cstream-ksplit-preload-c', args, meta) - - if width > 1: - args = {'ksplit': ks, 'csz': csz, 'blockx': blkx, - 'dtype': f'{dtype}{width}', 'width': width} + } | wmeta + yield from emit('bstream-msplit-preload-c', args, meta) + + for ks in ksplits: + # k-split B loading, C preloading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs + shared = (ks - 1)*csz*blkx*dsize*width meta = { - 'block': (blkx, ks, 1), 'shared': shared*width, - 'width': width, + 'block': (blkx, ks, 1), 'shared': shared, 'desc': ( - f'cstream-ksplit-width-preload-c/w{width}-' - f'k{ks}-c{csz}-x{blkx}' + f'cstream-ksplit-preload-c/' + f'{wpfx}k{ks}-c{csz}-x{blkx}' ) - } - yield from emit('cstream-ksplit-width-preload-c', args, meta) + } | wmeta + yield from emit('cstream-ksplit-preload-c', args, meta) def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index b40a0b9..a03a943 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -9,6 +9,67 @@ static inline __device__ ${dtype} make_zero() { return 0; } % endif +% if width == 1: +static inline __device__ ${dtype} +gimmik_vmul(${dtype} a, ${dtype} b) +{ + return a*b; +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return a + b; +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return acc + a*b; +} +% elif width == 2: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); +} +% elif width == 4: +static inline __device__ ${dtype} +gimmik_vmul(${dtype[:-1]} a, ${dtype} b) +{ + return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); +} + +static inline __device__ ${dtype} +gimmik_vadd(${dtype} a, ${dtype} b) +{ + return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} + +static inline __device__ ${dtype} +gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) +{ + // Keep the multiply-add expression visible to the compiler. + return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); +} +% else: +#error "HIP vector helpers only support width=2 or width=4" +% endif + static inline __device__ void nt_store_c(${dtype}* p, ${dtype} v) { @@ -41,4 +102,26 @@ nt_load_c(const ${dtype}* p) % endif } +<% nt_c = context.get('nt_c', True) %> + +static inline __device__ void +store_c(${dtype}* p, ${dtype} v) +{ +% if nt_c: + nt_store_c(p, v); +% else: + *p = v; +% endif +} + +static inline __device__ ${dtype} +load_c(const ${dtype}* p) +{ +% if nt_c: + return nt_load_c(p); +% else: + return *p; +% endif +} + ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-preload-c.mako index 8b6f008..6cabeb6 100644 --- a/gimmik/kernels/hip/bstream-msplit-preload-c.mako +++ b/gimmik/kernels/hip/bstream-msplit-preload-c.mako @@ -12,7 +12,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -38,18 +38,18 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endif % endfor + % if beta != 0: ## Preload C values for active rows owned by this m-split lane % for j, jx in enumerate(mx[cid]): % if afix[jx] != -1: - % if beta == 0: - csub[${j}] = make_zero(); - % elif beta == 1: - csub[${j}] = nt_load_c(&c[i + ${jx}*ldc]); + % if beta == 1: + csub[${j}] = load_c(&c[i + ${jx}*ldc]); % else: - csub[${j}] = ${beta}*nt_load_c(&c[i + ${jx}*ldc]); + csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${jx}*ldc])); % endif % endif % endfor + % endif } % endfor __syncthreads(); @@ -72,12 +72,18 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for kx in bchunks[bb]: bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): - % if jx != 0: - csub[${j}] += ${jx}*bv; + % if beta == 0: + % if jx != 0 and kx == afix[mcx[j]]: + csub[${j}] = gimmik_vmul(${jx}, bv); + % elif jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + % endif + % elif jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); % endif ## If we're done with this dot product then store to global % if kx == alix[mcx[j]]: - nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % endif % endfor % endfor @@ -85,9 +91,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if loop.parent.last: % for j, jx in enumerate(afix): % if jx == -1 and j % msplit == cid and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako deleted file mode 100644 index 4466f9f..0000000 --- a/gimmik/kernels/hip/bstream-msplit-width-preload-c.mako +++ /dev/null @@ -1,100 +0,0 @@ -<%inherit file='base'/> - -<%include file='vector'/> - -<% -mx = partition(A, into=msplit, by='rows') -bchunks = chunk(bix, bsz) -%> - -__global__ __launch_bounds__(${blockx*msplit}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - int i = blockDim.x*blockIdx.x + threadIdx.x; - - ${dtype} bv, csub[${-(-m // msplit)}]; - __shared__ ${dtype} bsub[2][${bsz}][${blockx}]; - -## Fill the initial shared memory block -% for cid in range(msplit): - if (i < n && threadIdx.y == ${cid}) - { - % for kx in bchunks[0]: - % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; - % endif - % endfor - - ## Preload C values for active rows owned by this m-split lane - % for j, jx in enumerate(mx[cid]): - % if afix[jx] != -1: - % if beta == 0: - csub[${j}] = make_zero(); - % elif beta == 1: - csub[${j}] = nt_load_c(&c[i + ${jx}*ldc]); - % else: - csub[${j}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${jx}*ldc])); - % endif - % endif - % endfor - } -% endfor - __syncthreads(); - -## Iterate over each row-chunk of B -% for bb in range(len(bchunks)): - ## Iterate over each row-chunk of C - % for cid, mcx in enumerate(mx): - if (i < n && threadIdx.y == ${cid}) - { - ## Start filling the next shared memory block - % if not loop.parent.last: - % for kx in bchunks[bb + 1]: - % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; - % endif - % endfor - % endif - ## Accumulate our dot products - % for kx in bchunks[bb]: - bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; - % for j, jx in enumerate(A[mcx, kx]): - % if jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); - % endif - ## If we're done with this dot product then store to global - % if kx == alix[mcx[j]]: - nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); - % endif - % endfor - % endfor - ## Handle rows of A which are all zero - % if loop.parent.last: - % for j, jx in enumerate(afix): - % if jx == -1 and j % msplit == cid and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); - % elif jx == -1 and j % msplit == cid and beta != 1: - nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); - % endif - % endfor - % endif - } - % endfor - __syncthreads(); -% endfor -} diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 6470477..35d336e 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -12,7 +12,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -60,17 +60,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): % if jx != 0 and kx == afix[mcx[j]]: - csub[${j}] = ${jx}*bv; + csub[${j}] = gimmik_vmul(${jx}, bv); % elif jx != 0: - csub[${j}] += ${jx}*bv; + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); % endif ## If we're done with this dot product then store to global % if kx == alix[mcx[j]] and beta == 0: - nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: - nt_store_c(&c[i + ${mcx[j]}*ldc], nt_load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); + store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(load_c(&c[i + ${mcx[j]}*ldc]), csub[${j}])); % elif kx == alix[mcx[j]]: - nt_store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*nt_load_c(&c[i + ${mcx[j]}*ldc])); + store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(csub[${j}], gimmik_vmul(${beta}, load_c(&c[i + ${mcx[j]}*ldc])))); % endif % endfor % endfor @@ -78,9 +78,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if loop.parent.last: % for j, jx in enumerate(afix): % if jx == -1 and j % msplit == cid and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc])*${beta}); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-preload-c.mako b/gimmik/kernels/hip/bstream-preload-c.mako index 30b08f6..095be83 100644 --- a/gimmik/kernels/hip/bstream-preload-c.mako +++ b/gimmik/kernels/hip/bstream-preload-c.mako @@ -7,7 +7,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -24,29 +24,35 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { ${dtype} bv, csub[${m}]; +% if beta != 0: ## Preload C values for rows which will receive a non-zero dot product % for j, jx in enumerate(afix): % if jx != -1: - % if beta == 0: - csub[${j}] = make_zero(); - % elif beta == 1: - csub[${j}] = nt_load_c(&c[i + ${j}*ldc]); + % if beta == 1: + csub[${j}] = load_c(&c[i + ${j}*ldc]); % else: - csub[${j}] = ${beta}*nt_load_c(&c[i + ${j}*ldc]); + csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); % endif % endif % endfor +% endif ## Iterate through the used rows of B % for kx in bix: bv = b[i + ${kx}*ldb]; % for j, jx in enumerate(A[:, kx]): - % if jx != 0: - csub[${j}] += ${jx}*bv; + % if beta == 0: + % if jx != 0 and kx == afix[j]: + csub[${j}] = gimmik_vmul(${jx}, bv); + % elif jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + % endif + % elif jx != 0: + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); % endif ## % if kx == alix[j]: - nt_store_c(&c[i + ${j}*ldc], csub[${j}]); + store_c(&c[i + ${j}*ldc], csub[${j}]); % endif % endfor % endfor @@ -54,9 +60,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Handle rows of A which are all zero % for j, jx in enumerate(afix): % if jx == -1 and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endfor } diff --git a/gimmik/kernels/hip/bstream-width-preload-c.mako b/gimmik/kernels/hip/bstream-width-preload-c.mako deleted file mode 100644 index 2c4e5c5..0000000 --- a/gimmik/kernels/hip/bstream-width-preload-c.mako +++ /dev/null @@ -1,65 +0,0 @@ -<%inherit file='base'/> - -<%include file='vector'/> - -__global__ __launch_bounds__(${blockx}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i < n) - { - ${dtype} bv, csub[${m}]; - -## Preload C values for rows which will receive a non-zero dot product -% for j, jx in enumerate(afix): - % if jx != -1: - % if beta == 0: - csub[${j}] = make_zero(); - % elif beta == 1: - csub[${j}] = nt_load_c(&c[i + ${j}*ldc]); - % else: - csub[${j}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); - % endif - % endif -% endfor - -## Iterate through the used rows of B -% for kx in bix: - bv = b[i + ${kx}*ldb]; - % for j, jx in enumerate(A[:, kx]): - % if jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); - % endif - ## - % if kx == alix[j]: - nt_store_c(&c[i + ${j}*ldc], csub[${j}]); - % endif - % endfor -% endfor - -## Handle rows of A which are all zero -% for j, jx in enumerate(afix): - % if jx == -1 and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); - % elif jx == -1 and beta != 1: - nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); - % endif -% endfor - } -} diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 9634c73..427dffc 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -7,7 +7,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -29,17 +29,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv = b[i + ${kx}*ldb]; % for j, jx in enumerate(A[:, kx]): % if jx != 0 and kx == afix[j]: - csub[${j}] = ${jx}*bv; + csub[${j}] = gimmik_vmul(${jx}, bv); % elif jx != 0: - csub[${j}] += ${jx}*bv; + csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); % endif ## % if kx == alix[j] and beta == 0: - nt_store_c(&c[i + ${j}*ldc], csub[${j}]); + store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: - nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + csub[${j}]); + store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), csub[${j}])); % elif kx == alix[j]: - nt_store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vadd(csub[${j}], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); % endif % endfor % endfor @@ -47,9 +47,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Handle rows of A which are all zero % for j, jx in enumerate(afix): % if jx == -1 and beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc])*${beta}); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endfor } diff --git a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako index 51f1db4..4550700 100644 --- a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako +++ b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako @@ -13,7 +13,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -44,24 +44,27 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endif % endfor <% - dotex = dot(lambda kx: f'bv[{kx}]', A[j, kbx]) + nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] has_dotp = A[j].any() + if nzixs: + first_l_idx, first_kx = nzixs[0] + dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" + for l_idx, kx in nzixs[1:]: + dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" + else: + dotex = 'make_zero()' %> - % if dotex != '0.0': dotp = ${dotex}; - % else: - dotp = make_zero(); - % endif ## Save to a register % if loop.index % ksplit == bid: % if beta == 0: cv[${loop.index // ksplit}] = dotp; % elif beta == 1 and has_dotp: - cv[${loop.index // ksplit}] = nt_load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] += dotp; + cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); % elif has_dotp: - cv[${loop.index // ksplit}] = ${beta}*nt_load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] += dotp; + cv[${loop.index // ksplit}] = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); + cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); % endif ## Save to shared memory % else: @@ -79,20 +82,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for j in cchunk: % if loop.index % ksplit == bid: <% has_dotp = A[j].any() %> + <% + sum_expr = f"cv[{loop.index // ksplit}]" + for s_idx in range(ksplit - 1): + sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" + %> % if beta == 0: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; - nt_store_c(&c[i + ${j}*ldc], dotp); + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1 and has_dotp: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; - nt_store_c(&c[i + ${j}*ldc], dotp); + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); % elif beta != 1 and has_dotp: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; - nt_store_c(&c[i + ${j}*ldc], dotp); + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); % elif beta != 1: - nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako deleted file mode 100644 index bdac6dc..0000000 --- a/gimmik/kernels/hip/cstream-ksplit-width-preload-c.mako +++ /dev/null @@ -1,114 +0,0 @@ -<%inherit file='base'/> - -<%include file='vector'/> - -<% -kparts = partition(A, ksplit, by='cols') -cchunks = chunk(range(m), csz) -loaded = set() -%> - -__global__ __launch_bounds__(${blockx*ksplit}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - int i = blockDim.x*blockIdx.x + threadIdx.x; - - ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp; - __shared__ ${dtype} csub[${ksplit - 1}][${csz}][${blockx}]; - -## Iterate over the row-partitions of C -% for cchunk in cchunks: - ## Iterate over the column-partitions of B - % for bid, kbx in enumerate(kparts): - if (i < n && threadIdx.y == ${bid}) - { - ## Evaluate our partial dot products - % for j in cchunk: - ## Load in any missing parts of B - % for kx in kbx: - % if A[j, kx] != 0 and kx not in loaded: - bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> - % endif - % endfor - - ## Expand vectorized partial dot product - <% - nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] - - has_dotp = A[j].any() - if not nzixs: - dotex = 'make_zero()' - else: - first_l_idx, first_kx = nzixs[0] - dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" - for l_idx, kx in nzixs[1:]: - dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" - %> - dotp = ${dotex}; - - ## Save to a register - % if loop.index % ksplit == bid: - % if beta == 0: - cv[${loop.index // ksplit}] = dotp; - % elif beta == 1 and has_dotp: - cv[${loop.index // ksplit}] = nt_load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); - % elif has_dotp: - cv[${loop.index // ksplit}] = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); - cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); - % endif - ## Save to shared memory - % else: - csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; - % endif - % endfor - } - % endfor - __syncthreads(); - - ## Sum and output the final set of dot products - % for bid, kbx in enumerate(kparts): - if (i < n && threadIdx.y == ${bid}) - { - % for j in cchunk: - % if loop.index % ksplit == bid: - <% - has_dotp = A[j].any() - sum_expr = f"cv[{loop.index // ksplit}]" - for s_idx in range(ksplit - 1): - sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" - %> - % if beta == 0: - dotp = ${sum_expr}; - nt_store_c(&c[i + ${j}*ldc], dotp); - % elif beta == 1 and has_dotp: - dotp = ${sum_expr}; - nt_store_c(&c[i + ${j}*ldc], dotp); - % elif beta != 1 and has_dotp: - dotp = ${sum_expr}; - nt_store_c(&c[i + ${j}*ldc], dotp); - % elif beta != 1: - nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); - % endif - % endif - % endfor - } - % endfor - __syncthreads(); -% endfor -} diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index 6fd3210..9c5fc09 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -13,7 +13,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -43,11 +43,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> % endif % endfor - % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0': + <% + nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] + if nzixs: + first_l_idx, first_kx = nzixs[0] + dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" + for l_idx, kx in nzixs[1:]: + dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; - % else: - dotp = make_zero(); - % endif ## Save to a register % if loop.index % ksplit == bid: cv[${loop.index // ksplit}] = dotp; @@ -66,14 +72,18 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Sum and output the final set of dot products % for j in cchunk: % if loop.index % ksplit == bid: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; + <% + sum_expr = f"cv[{loop.index // ksplit}]" + for s_idx in range(ksplit - 1): + sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" + %> + dotp = ${sum_expr}; % if beta == 0: - nt_store_c(&c[i + ${j}*ldc], dotp); + store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + dotp); + store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), dotp)); % else: - nt_store_c(&c[i + ${j}*ldc], dotp + ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vadd(dotp, gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-preload-c.mako b/gimmik/kernels/hip/cstream-preload-c.mako index 041e674..eebb602 100644 --- a/gimmik/kernels/hip/cstream-preload-c.mako +++ b/gimmik/kernels/hip/cstream-preload-c.mako @@ -9,7 +9,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -26,24 +26,34 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) if (i < n) { % for j, jx in enumerate(A): - % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0': + <% + nzixs = [kx for kx, val in enumerate(jx) if val != 0] + if nzixs: + first_kx = nzixs[0] + dotex = f"gimmik_vmul({jx[first_kx]}, b[i + {first_kx}*ldb])" + for kx in nzixs[1:]: + dotex = f"gimmik_vmadd({dotex}, {jx[kx]}, b[i + {kx}*ldb])" + else: + dotex = 'make_zero()' + %> + % if nzixs: % if beta == 0: dotp = ${dotex}; - nt_store_c(&c[i + ${j}*ldc], dotp); + store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - dotp = nt_load_c(&c[i + ${j}*ldc]); - dotp += ${dotex}; - nt_store_c(&c[i + ${j}*ldc], dotp); + dotp = load_c(&c[i + ${j}*ldc]); + dotp = gimmik_vadd(dotp, ${dotex}); + store_c(&c[i + ${j}*ldc], dotp); % else: - dotp = ${beta}*nt_load_c(&c[i + ${j}*ldc]); - dotp += ${dotex}; - nt_store_c(&c[i + ${j}*ldc], dotp); + dotp = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); + dotp = gimmik_vadd(dotp, ${dotex}); + store_c(&c[i + ${j}*ldc], dotp); % endif % else: % if beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); + store_c(&c[i + ${j}*ldc], make_zero()); % elif beta != 1: - nt_store_c(&c[i + ${j}*ldc], ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-width-preload-c.mako b/gimmik/kernels/hip/cstream-width-preload-c.mako deleted file mode 100644 index 86acfcb..0000000 --- a/gimmik/kernels/hip/cstream-width-preload-c.mako +++ /dev/null @@ -1,66 +0,0 @@ -<%inherit file='base'/> - -<%include file='vector'/> - -__global__ __launch_bounds__(${blockx}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - const int i = blockDim.x*blockIdx.x + threadIdx.x; - ${dtype} bv, dotp; - - if (i < n) - { -% for j, row in enumerate(A): - <% - nzixs = [kx for kx, val in enumerate(row) if val != 0] - %> - % if nzixs: - % if beta == 0: - <% first_kx = nzixs[0] %> - bv = b[i + ${first_kx}*ldb]; - dotp = gimmik_vmul(${row[first_kx]}, bv); - % for kx in nzixs[1:]: - bv = b[i + ${kx}*ldb]; - dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); - % endfor - nt_store_c(&c[i + ${j}*ldc], dotp); - % elif beta == 1: - dotp = nt_load_c(&c[i + ${j}*ldc]); - % for kx in nzixs: - bv = b[i + ${kx}*ldb]; - dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); - % endfor - nt_store_c(&c[i + ${j}*ldc], dotp); - % else: - dotp = gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc])); - % for kx in nzixs: - bv = b[i + ${kx}*ldb]; - dotp = gimmik_vmadd(dotp, ${row[kx]}, bv); - % endfor - nt_store_c(&c[i + ${j}*ldc], dotp); - % endif - % else: - % if beta == 0: - nt_store_c(&c[i + ${j}*ldc], make_zero()); - % elif beta != 1: - nt_store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, nt_load_c(&c[i + ${j}*ldc]))); - % endif - % endif -% endfor - } -} diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index 0651e87..1b7b312 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -9,7 +9,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -26,17 +26,26 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) if (i < n) { % for j, jx in enumerate(A): - % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0': + <% + nzixs = [kx for kx, val in enumerate(jx) if val != 0] + if nzixs: + first_kx = nzixs[0] + dotex = f"gimmik_vmul({jx[first_kx]}, b[i + {first_kx}*ldb])" + for kx in nzixs[1:]: + dotex = f"gimmik_vmadd({dotex}, {jx[kx]}, b[i + {kx}*ldb])" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; - % else: + % if not nzixs: dotp = make_zero(); % endif % if beta == 0: - nt_store_c(&c[i + ${j}*ldc], dotp); - % elif beta == 1 and dotex != '0.0': - nt_store_c(&c[i + ${j}*ldc], nt_load_c(&c[i + ${j}*ldc]) + dotp); + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1 and nzixs: + store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), dotp)); % else: - nt_store_c(&c[i + ${j}*ldc], dotp + ${beta}*nt_load_c(&c[i + ${j}*ldc])); + store_c(&c[i + ${j}*ldc], gimmik_vadd(dotp, gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); % endif % endfor } diff --git a/gimmik/kernels/hip/vector.mako b/gimmik/kernels/hip/vector.mako deleted file mode 100644 index 268d6ab..0000000 --- a/gimmik/kernels/hip/vector.mako +++ /dev/null @@ -1,41 +0,0 @@ -% if width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} - -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} -% else: -#error "HIP vector helpers only support width=2 or width=4" -% endif From e9b921a99f7ce476c98fc59bace1a4e72c26259b Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Tue, 23 Jun 2026 05:14:11 -0500 Subject: [PATCH 07/25] Use blockx launch bounds for HIP cstream preload --- gimmik/kernels/hip/cstream-preload-c.mako | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gimmik/kernels/hip/cstream-preload-c.mako b/gimmik/kernels/hip/cstream-preload-c.mako index eebb602..a8d7e31 100644 --- a/gimmik/kernels/hip/cstream-preload-c.mako +++ b/gimmik/kernels/hip/cstream-preload-c.mako @@ -2,7 +2,7 @@ <% ksplit = 2 if m < 36 else 1 %> -__global__ __launch_bounds__(128) void +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, From 2aa2577a4a99a5479f57679c62a9debacf083ef6 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Tue, 23 Jun 2026 21:49:28 -0500 Subject: [PATCH 08/25] Always use non-temporal C accesses for HIP --- gimmik/kernels/hip/base.mako | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index a03a943..831b8e0 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -102,26 +102,16 @@ nt_load_c(const ${dtype}* p) % endif } -<% nt_c = context.get('nt_c', True) %> - static inline __device__ void store_c(${dtype}* p, ${dtype} v) { -% if nt_c: nt_store_c(p, v); -% else: - *p = v; -% endif } static inline __device__ ${dtype} load_c(const ${dtype}* p) { -% if nt_c: return nt_load_c(p); -% else: - return *p; -% endif } ${next.body()} From be1c1dbb58e34353e73d7b863aa6eceecf9a0880 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Wed, 24 Jun 2026 22:47:02 +0800 Subject: [PATCH 09/25] feat(hip): add non-temporal B-load (NTB) variants for bstream-msplit On memory-bound operators the B matrix is read once from HBM and reused only within a work-group via LDS -- it is never re-read across blocks. A normal global load still allocates B's line in L2, which is pure overhead: the line is never reused, it evicts genuinely-reusable data, and it adds cache-allocate/eviction traffic. This is the read-side mirror of the non-temporal C store we already use. NTB loads B with a non-temporal load (load_b -> __builtin_nontemporal path) so B bypasses L2 instead of polluting it. It moves the same number of bytes but keeps the cache clean, raising effective bandwidth. Implemented as a flag on the existing templates rather than new files: - base.mako: add a load_b() wrapper (non-temporal B load). - bstream-msplit{,-preload-c}.mako: gate the B read behind an `ntload` flag (context.get('ntload', False)); renders byte-identically to the plain kernel when the flag is absent. - hip.py: emit `*-ntb` variants by passing ntload=True inside the existing width loop, so NTB combines with width (w1/w2) automatically. Backward-compatible (plain variants unchanged) and CDNA-gated like the other tuned variants. On MI300X (gfx942) NTB passes the accuracy check (~1e-15) and wins the autotune in the majority of memory-bound cases (~+4.5% bandwidth on those), being chosen over the plain bstream-msplit. --- gimmik/hip.py | 17 +++++++++++++++++ gimmik/kernels/hip/base.mako | 9 +++++++++ .../kernels/hip/bstream-msplit-preload-c.mako | 6 ++++-- gimmik/kernels/hip/bstream-msplit.mako | 6 ++++-- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 57fa394..860ed6c 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -78,6 +78,13 @@ def emit(name, args, meta): } | wmeta yield from emit('bstream-msplit', args, meta) + # non-temporal B-load variant: B is read-once -> skip L2 alloc + nmeta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit-ntb/{wpfx}m{ms}-b{bsz}-x{blkx}' + } | wmeta + yield from emit('bstream-msplit', args | {'ntload': True}, nmeta) + for ks in ksplits: # k-split B loading, C streaming kernel args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs @@ -112,6 +119,16 @@ def emit(name, args, meta): } | wmeta yield from emit('bstream-msplit-preload-c', args, meta) + # non-temporal B-load variant + nmeta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': ( + f'bstream-msplit-preload-c-ntb/' + f'{wpfx}m{ms}-b{bsz}-x{blkx}' + ) + } | wmeta + yield from emit('bstream-msplit-preload-c', args | {'ntload': True}, nmeta) + for ks in ksplits: # k-split B loading, C preloading, C streaming kernel args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 831b8e0..5291538 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -114,4 +114,13 @@ load_c(const ${dtype}* p) return nt_load_c(p); } +static inline __device__ ${dtype} +load_b(const ${dtype}* p) +{ + // B is read-once (reused only within the block via LDS, never re-read across + // blocks), so load it non-temporally to avoid polluting L2 -- the read-side + // twin of the non-temporal C store. + return nt_load_c(p); +} + ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-preload-c.mako index 6cabeb6..36783c0 100644 --- a/gimmik/kernels/hip/bstream-msplit-preload-c.mako +++ b/gimmik/kernels/hip/bstream-msplit-preload-c.mako @@ -3,6 +3,8 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) +ntload = context.get('ntload', False) +bload = (lambda kx: f'load_b(&b[i + {kx}*ldb])') if ntload else (lambda kx: f'b[i + {kx}*ldb]') %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -34,7 +36,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[0][${loop.index}][threadIdx.x] = ${bload(kx)}; % endif % endfor @@ -64,7 +66,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = ${bload(kx)}; % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 35d336e..30f2c57 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -3,6 +3,8 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) +ntload = context.get('ntload', False) +bload = (lambda kx: f'load_b(&b[i + {kx}*ldb])') if ntload else (lambda kx: f'b[i + {kx}*ldb]') %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -34,7 +36,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[0][${loop.index}][threadIdx.x] = ${bload(kx)}; % endif % endfor } @@ -51,7 +53,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = ${bload(kx)}; % endif % endfor % endif From 280e948c90da7643b511c5064ef9f51c784d27ba Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Wed, 24 Jun 2026 22:39:34 -0500 Subject: [PATCH 10/25] Use non-temporal B loads by default for HIP --- gimmik/hip.py | 17 ---- gimmik/kernels/hip/base.mako | 94 +++++-------------- .../kernels/hip/bstream-msplit-preload-c.mako | 16 ++-- gimmik/kernels/hip/bstream-msplit.mako | 16 ++-- gimmik/kernels/hip/bstream-preload-c.mako | 12 +-- gimmik/kernels/hip/bstream.mako | 12 +-- .../kernels/hip/cstream-ksplit-preload-c.mako | 14 +-- gimmik/kernels/hip/cstream-ksplit.mako | 10 +- gimmik/kernels/hip/cstream-preload-c.mako | 12 +-- gimmik/kernels/hip/cstream.mako | 8 +- 10 files changed, 72 insertions(+), 139 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 860ed6c..57fa394 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -78,13 +78,6 @@ def emit(name, args, meta): } | wmeta yield from emit('bstream-msplit', args, meta) - # non-temporal B-load variant: B is read-once -> skip L2 alloc - nmeta = { - 'block': (blkx, ms, 1), 'shared': shared, - 'desc': f'bstream-msplit-ntb/{wpfx}m{ms}-b{bsz}-x{blkx}' - } | wmeta - yield from emit('bstream-msplit', args | {'ntload': True}, nmeta) - for ks in ksplits: # k-split B loading, C streaming kernel args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs @@ -119,16 +112,6 @@ def emit(name, args, meta): } | wmeta yield from emit('bstream-msplit-preload-c', args, meta) - # non-temporal B-load variant - nmeta = { - 'block': (blkx, ms, 1), 'shared': shared, - 'desc': ( - f'bstream-msplit-preload-c-ntb/' - f'{wpfx}m{ms}-b{bsz}-x{blkx}' - ) - } | wmeta - yield from emit('bstream-msplit-preload-c', args | {'ntload': True}, nmeta) - for ks in ksplits: # k-split B loading, C preloading, C streaming kernel args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 5291538..95e622f 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -1,77 +1,34 @@ % if dtype.endswith('4'): -static inline __device__ ${dtype} make_zero() -{ return make_${dtype}(0, 0, 0, 0); } -% elif dtype.endswith('2'): -static inline __device__ ${dtype} make_zero() -{ return make_${dtype}(0, 0); } -% else: -static inline __device__ ${dtype} make_zero() -{ return 0; } -% endif +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } -% if width == 1: -static inline __device__ ${dtype} -gimmik_vmul(${dtype} a, ${dtype} b) -{ - return a*b; -} +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); } -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return a + b; -} - -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return acc + a*b; -} -% elif width == 2: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y); -} +inline __device__ void operator+=(${dtype} &a, ${dtype} b) +{ a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; } -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y); -} +inline __device__ ${dtype} make_zero() +{ return make_${dtype}(0, 0, 0, 0); } +% elif dtype.endswith('2'): +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y); } -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y); -} -% elif width == 4: -static inline __device__ ${dtype} -gimmik_vmul(${dtype[:-1]} a, ${dtype} b) -{ - return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); -} +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y); } -static inline __device__ ${dtype} -gimmik_vadd(${dtype} a, ${dtype} b) -{ - return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); -} +inline __device__ void operator+=(${dtype} &a, ${dtype} b) +{ a.x += b.x; a.y += b.y; } -static inline __device__ ${dtype} -gimmik_vmadd(${dtype} acc, ${dtype[:-1]} a, ${dtype} b) -{ - // Keep the multiply-add expression visible to the compiler. - return make_${dtype}(acc.x + a*b.x, acc.y + a*b.y, acc.z + a*b.z, acc.w + a*b.w); -} +inline __device__ ${dtype} make_zero() +{ return make_${dtype}(0, 0); } % else: -#error "HIP vector helpers only support width=2 or width=4" +inline __device__ ${dtype} make_zero() +{ return 0; } % endif static inline __device__ void -nt_store_c(${dtype}* p, ${dtype} v) +nt_store(${dtype}* p, ${dtype} v) { % if dtype.endswith('4'): __builtin_nontemporal_store(v.x, &p->x); @@ -87,7 +44,7 @@ nt_store_c(${dtype}* p, ${dtype} v) } static inline __device__ ${dtype} -nt_load_c(const ${dtype}* p) +nt_load(const ${dtype}* p) { % if dtype.endswith('4'): return make_${dtype}(__builtin_nontemporal_load(&p->x), @@ -105,22 +62,19 @@ nt_load_c(const ${dtype}* p) static inline __device__ void store_c(${dtype}* p, ${dtype} v) { - nt_store_c(p, v); + nt_store(p, v); } static inline __device__ ${dtype} load_c(const ${dtype}* p) { - return nt_load_c(p); + return nt_load(p); } static inline __device__ ${dtype} load_b(const ${dtype}* p) { - // B is read-once (reused only within the block via LDS, never re-read across - // blocks), so load it non-temporally to avoid polluting L2 -- the read-side - // twin of the non-temporal C store. - return nt_load_c(p); + return nt_load(p); } ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-preload-c.mako index 36783c0..e58fa43 100644 --- a/gimmik/kernels/hip/bstream-msplit-preload-c.mako +++ b/gimmik/kernels/hip/bstream-msplit-preload-c.mako @@ -3,8 +3,6 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) -ntload = context.get('ntload', False) -bload = (lambda kx: f'load_b(&b[i + {kx}*ldb])') if ntload else (lambda kx: f'b[i + {kx}*ldb]') %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -36,7 +34,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = ${bload(kx)}; + bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor @@ -47,7 +45,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if beta == 1: csub[${j}] = load_c(&c[i + ${jx}*ldc]); % else: - csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${jx}*ldc])); + csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]); % endif % endif % endfor @@ -66,7 +64,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = ${bload(kx)}; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor % endif @@ -76,12 +74,12 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for j, jx in enumerate(A[mcx, kx]): % if beta == 0: % if jx != 0 and kx == afix[mcx[j]]: - csub[${j}] = gimmik_vmul(${jx}, bv); + csub[${j}] = ${jx}*bv; % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif ## If we're done with this dot product then store to global % if kx == alix[mcx[j]]: @@ -95,7 +93,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if jx == -1 and j % msplit == cid and beta == 0: store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 30f2c57..0681783 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -3,8 +3,6 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) -ntload = context.get('ntload', False) -bload = (lambda kx: f'load_b(&b[i + {kx}*ldb])') if ntload else (lambda kx: f'b[i + {kx}*ldb]') %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -36,7 +34,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = ${bload(kx)}; + bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor } @@ -53,7 +51,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = ${bload(kx)}; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor % endif @@ -62,17 +60,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): % if jx != 0 and kx == afix[mcx[j]]: - csub[${j}] = gimmik_vmul(${jx}, bv); + csub[${j}] = ${jx}*bv; % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif ## If we're done with this dot product then store to global % if kx == alix[mcx[j]] and beta == 0: store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: - store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(load_c(&c[i + ${mcx[j]}*ldc]), csub[${j}])); + store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); % elif kx == alix[mcx[j]]: - store_c(&c[i + ${mcx[j]}*ldc], gimmik_vadd(csub[${j}], gimmik_vmul(${beta}, load_c(&c[i + ${mcx[j]}*ldc])))); + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${mcx[j]}*ldc])); % endif % endfor % endfor @@ -82,7 +80,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if jx == -1 and j % msplit == cid and beta == 0: store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream-preload-c.mako b/gimmik/kernels/hip/bstream-preload-c.mako index 095be83..a7140c6 100644 --- a/gimmik/kernels/hip/bstream-preload-c.mako +++ b/gimmik/kernels/hip/bstream-preload-c.mako @@ -31,7 +31,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if beta == 1: csub[${j}] = load_c(&c[i + ${j}*ldc]); % else: - csub[${j}] = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); + csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]); % endif % endif % endfor @@ -39,16 +39,16 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Iterate through the used rows of B % for kx in bix: - bv = b[i + ${kx}*ldb]; + bv = load_b(&b[i + ${kx}*ldb]); % for j, jx in enumerate(A[:, kx]): % if beta == 0: % if jx != 0 and kx == afix[j]: - csub[${j}] = gimmik_vmul(${jx}, bv); + csub[${j}] = ${jx}*bv; % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif ## % if kx == alix[j]: @@ -62,7 +62,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if jx == -1 and beta == 0: store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 427dffc..10df790 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -26,20 +26,20 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Iterare through the used rows of B % for kx in bix: - bv = b[i + ${kx}*ldb]; + bv = load_b(&b[i + ${kx}*ldb]); % for j, jx in enumerate(A[:, kx]): % if jx != 0 and kx == afix[j]: - csub[${j}] = gimmik_vmul(${jx}, bv); + csub[${j}] = ${jx}*bv; % elif jx != 0: - csub[${j}] = gimmik_vmadd(csub[${j}], ${jx}, bv); + csub[${j}] += ${jx}*bv; % endif ## % if kx == alix[j] and beta == 0: store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: - store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), csub[${j}])); + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]); % elif kx == alix[j]: - store_c(&c[i + ${j}*ldc], gimmik_vadd(csub[${j}], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); + store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endfor @@ -49,7 +49,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if jx == -1 and beta == 0: store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } diff --git a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako index 4550700..b922666 100644 --- a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako +++ b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako @@ -48,9 +48,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) has_dotp = A[j].any() if nzixs: first_l_idx, first_kx = nzixs[0] - dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" + dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" for l_idx, kx in nzixs[1:]: - dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" + dotex = f"{dotex} + {A[j, kx]}*bv[{l_idx}]" else: dotex = 'make_zero()' %> @@ -61,10 +61,10 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) cv[${loop.index // ksplit}] = dotp; % elif beta == 1 and has_dotp: cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); + cv[${loop.index // ksplit}] += dotp; % elif has_dotp: - cv[${loop.index // ksplit}] = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); - cv[${loop.index // ksplit}] = gimmik_vadd(cv[${loop.index // ksplit}], dotp); + cv[${loop.index // ksplit}] = ${beta}*load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; % endif ## Save to shared memory % else: @@ -85,7 +85,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) <% sum_expr = f"cv[{loop.index // ksplit}]" for s_idx in range(ksplit - 1): - sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" + sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" %> % if beta == 0: dotp = ${sum_expr}; @@ -97,7 +97,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) dotp = ${sum_expr}; store_c(&c[i + ${j}*ldc], dotp); % elif beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index 9c5fc09..5fc6711 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -47,9 +47,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] if nzixs: first_l_idx, first_kx = nzixs[0] - dotex = f"gimmik_vmul({A[j, first_kx]}, bv[{first_l_idx}])" + dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" for l_idx, kx in nzixs[1:]: - dotex = f"gimmik_vmadd({dotex}, {A[j, kx]}, bv[{l_idx}])" + dotex = f"{dotex} + {A[j, kx]}*bv[{l_idx}]" else: dotex = 'make_zero()' %> @@ -75,15 +75,15 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) <% sum_expr = f"cv[{loop.index // ksplit}]" for s_idx in range(ksplit - 1): - sum_expr = f"gimmik_vadd({sum_expr}, csub[{s_idx}][{loop.index}][threadIdx.x])" + sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" %> dotp = ${sum_expr}; % if beta == 0: store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), dotp)); + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - store_c(&c[i + ${j}*ldc], gimmik_vadd(dotp, gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream-preload-c.mako b/gimmik/kernels/hip/cstream-preload-c.mako index a8d7e31..c9a83b0 100644 --- a/gimmik/kernels/hip/cstream-preload-c.mako +++ b/gimmik/kernels/hip/cstream-preload-c.mako @@ -30,9 +30,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) nzixs = [kx for kx, val in enumerate(jx) if val != 0] if nzixs: first_kx = nzixs[0] - dotex = f"gimmik_vmul({jx[first_kx]}, b[i + {first_kx}*ldb])" + dotex = f"{jx[first_kx]}*b[i + {first_kx}*ldb]" for kx in nzixs[1:]: - dotex = f"gimmik_vmadd({dotex}, {jx[kx]}, b[i + {kx}*ldb])" + dotex = f"{dotex} + {jx[kx]}*b[i + {kx}*ldb]" else: dotex = 'make_zero()' %> @@ -42,18 +42,18 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: dotp = load_c(&c[i + ${j}*ldc]); - dotp = gimmik_vadd(dotp, ${dotex}); + dotp += ${dotex}; store_c(&c[i + ${j}*ldc], dotp); % else: - dotp = gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])); - dotp = gimmik_vadd(dotp, ${dotex}); + dotp = ${beta}*load_c(&c[i + ${j}*ldc]); + dotp += ${dotex}; store_c(&c[i + ${j}*ldc], dotp); % endif % else: % if beta == 0: store_c(&c[i + ${j}*ldc], make_zero()); % elif beta != 1: - store_c(&c[i + ${j}*ldc], gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc]))); + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index 1b7b312..4ea995e 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -30,9 +30,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) nzixs = [kx for kx, val in enumerate(jx) if val != 0] if nzixs: first_kx = nzixs[0] - dotex = f"gimmik_vmul({jx[first_kx]}, b[i + {first_kx}*ldb])" + dotex = f"{jx[first_kx]}*b[i + {first_kx}*ldb]" for kx in nzixs[1:]: - dotex = f"gimmik_vmadd({dotex}, {jx[kx]}, b[i + {kx}*ldb])" + dotex = f"{dotex} + {jx[kx]}*b[i + {kx}*ldb]" else: dotex = 'make_zero()' %> @@ -43,9 +43,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if beta == 0: store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1 and nzixs: - store_c(&c[i + ${j}*ldc], gimmik_vadd(load_c(&c[i + ${j}*ldc]), dotp)); + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - store_c(&c[i + ${j}*ldc], gimmik_vadd(dotp, gimmik_vmul(${beta}, load_c(&c[i + ${j}*ldc])))); + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } From c06216d0e09f76313acf0cf53cccd5529d7f6eab Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Wed, 24 Jun 2026 23:00:57 -0500 Subject: [PATCH 11/25] Make HIP preload-C a template option --- gimmik/hip.py | 11 +- .../kernels/hip/bstream-msplit-preload-c.mako | 104 ----------------- gimmik/kernels/hip/bstream-msplit.mako | 22 +++- gimmik/kernels/hip/bstream-preload-c.mako | 69 ----------- gimmik/kernels/hip/bstream.mako | 25 +++- .../kernels/hip/cstream-ksplit-preload-c.mako | 108 ------------------ gimmik/kernels/hip/cstream-ksplit.mako | 28 ++++- gimmik/kernels/hip/cstream-preload-c.mako | 61 ---------- gimmik/kernels/hip/cstream.mako | 25 +++- 9 files changed, 96 insertions(+), 357 deletions(-) delete mode 100644 gimmik/kernels/hip/bstream-msplit-preload-c.mako delete mode 100644 gimmik/kernels/hip/bstream-preload-c.mako delete mode 100644 gimmik/kernels/hip/cstream-ksplit-preload-c.mako delete mode 100644 gimmik/kernels/hip/cstream-preload-c.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index 57fa394..1fceb01 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -19,6 +19,9 @@ def emit(name, args, meta): if threads <= max_block_threads and shared <= max_shared: yield (name, args, meta) + def emit_preload(name, args, meta): + yield from emit(name, args | {'preload': True}, meta) + blkx = self.basemeta['block'][0] # B loading, C streaming kernel @@ -92,12 +95,12 @@ def emit(name, args, meta): args = {'blockx': blkx} | wargs meta = {'block': (blkx, 1, 1), 'desc': f'cstream-preload-c/{wpfx}x{blkx}'} | wmeta - yield from emit('cstream-preload-c', args, meta) + yield from emit_preload('cstream', args, meta) # B streaming, C preloading, C accumulation kernel meta = {'block': (blkx, 1, 1), 'desc': f'bstream-preload-c/{wpfx}x{blkx}'} | wmeta - yield from emit('bstream-preload-c', args, meta) + yield from emit_preload('bstream', args, meta) for ms in msplits: # m-split B streaming, C preloading, C accumulation kernel @@ -110,7 +113,7 @@ def emit(name, args, meta): f'{wpfx}m{ms}-b{bsz}-x{blkx}' ) } | wmeta - yield from emit('bstream-msplit-preload-c', args, meta) + yield from emit_preload('bstream-msplit', args, meta) for ks in ksplits: # k-split B loading, C preloading, C streaming kernel @@ -123,7 +126,7 @@ def emit(name, args, meta): f'{wpfx}k{ks}-c{csz}-x{blkx}' ) } | wmeta - yield from emit('cstream-ksplit-preload-c', args, meta) + yield from emit_preload('cstream-ksplit', args, meta) def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/bstream-msplit-preload-c.mako b/gimmik/kernels/hip/bstream-msplit-preload-c.mako deleted file mode 100644 index e58fa43..0000000 --- a/gimmik/kernels/hip/bstream-msplit-preload-c.mako +++ /dev/null @@ -1,104 +0,0 @@ -<%inherit file='base'/> - -<% -mx = partition(A, into=msplit, by='rows') -bchunks = chunk(bix, bsz) -%> - -__global__ __launch_bounds__(${blockx*msplit}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - int i = blockDim.x*blockIdx.x + threadIdx.x; - - ${dtype} bv, csub[${-(-m // msplit)}]; - __shared__ ${dtype} bsub[2][${bsz}][${blockx}]; - -## Fill the initial shared memory block -% for cid in range(msplit): - if (i < n && threadIdx.y == ${cid}) - { - % for kx in bchunks[0]: - % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); - % endif - % endfor - - % if beta != 0: - ## Preload C values for active rows owned by this m-split lane - % for j, jx in enumerate(mx[cid]): - % if afix[jx] != -1: - % if beta == 1: - csub[${j}] = load_c(&c[i + ${jx}*ldc]); - % else: - csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]); - % endif - % endif - % endfor - % endif - } -% endfor - __syncthreads(); - -## Iterate over each row-chunk of B -% for bb in range(len(bchunks)): - ## Iterate over each row-chunk of C - % for cid, mcx in enumerate(mx): - if (i < n && threadIdx.y == ${cid}) - { - ## Start filling the next shared memory block - % if not loop.parent.last: - % for kx in bchunks[bb + 1]: - % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); - % endif - % endfor - % endif - ## Accumulate our dot products - % for kx in bchunks[bb]: - bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; - % for j, jx in enumerate(A[mcx, kx]): - % if beta == 0: - % if jx != 0 and kx == afix[mcx[j]]: - csub[${j}] = ${jx}*bv; - % elif jx != 0: - csub[${j}] += ${jx}*bv; - % endif - % elif jx != 0: - csub[${j}] += ${jx}*bv; - % endif - ## If we're done with this dot product then store to global - % if kx == alix[mcx[j]]: - store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); - % endif - % endfor - % endfor - ## Handle rows of A which are all zero - % if loop.parent.last: - % for j, jx in enumerate(afix): - % if jx == -1 and j % msplit == cid and beta == 0: - store_c(&c[i + ${j}*ldc], make_zero()); - % elif jx == -1 and j % msplit == cid and beta != 1: - store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); - % endif - % endfor - % endif - } - % endfor - __syncthreads(); -% endfor -} diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 0681783..52853f4 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -3,6 +3,7 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -37,6 +38,19 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor + + % if preload and beta != 0: + ## Preload C values for active rows owned by this m-split lane + % for j, jx in enumerate(mx[cid]): + % if afix[jx] != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${jx}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]); + % endif + % endif + % endfor + % endif } % endfor __syncthreads(); @@ -59,13 +73,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for kx in bchunks[bb]: bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): - % if jx != 0 and kx == afix[mcx[j]]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[mcx[j]]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## If we're done with this dot product then store to global - % if kx == alix[mcx[j]] and beta == 0: + % if preload and kx == alix[mcx[j]]: + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + % elif kx == alix[mcx[j]] and beta == 0: store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); diff --git a/gimmik/kernels/hip/bstream-preload-c.mako b/gimmik/kernels/hip/bstream-preload-c.mako deleted file mode 100644 index a7140c6..0000000 --- a/gimmik/kernels/hip/bstream-preload-c.mako +++ /dev/null @@ -1,69 +0,0 @@ -<%inherit file='base'/> - -__global__ __launch_bounds__(${blockx}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i < n) - { - ${dtype} bv, csub[${m}]; - -% if beta != 0: -## Preload C values for rows which will receive a non-zero dot product -% for j, jx in enumerate(afix): - % if jx != -1: - % if beta == 1: - csub[${j}] = load_c(&c[i + ${j}*ldc]); - % else: - csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]); - % endif - % endif -% endfor -% endif - -## Iterate through the used rows of B -% for kx in bix: - bv = load_b(&b[i + ${kx}*ldb]); - % for j, jx in enumerate(A[:, kx]): - % if beta == 0: - % if jx != 0 and kx == afix[j]: - csub[${j}] = ${jx}*bv; - % elif jx != 0: - csub[${j}] += ${jx}*bv; - % endif - % elif jx != 0: - csub[${j}] += ${jx}*bv; - % endif - ## - % if kx == alix[j]: - store_c(&c[i + ${j}*ldc], csub[${j}]); - % endif - % endfor -% endfor - -## Handle rows of A which are all zero -% for j, jx in enumerate(afix): - % if jx == -1 and beta == 0: - store_c(&c[i + ${j}*ldc], make_zero()); - % elif jx == -1 and beta != 1: - store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); - % endif -% endfor - } -} diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 10df790..1e7a70b 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -1,5 +1,7 @@ <%inherit file='base'/> +<% preload = context.get('preload', False) %> + __global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, @@ -24,17 +26,34 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { ${dtype} bv, csub[${m}]; -## Iterare through the used rows of B +% if preload and beta != 0: +## Preload C values for rows which will receive a non-zero dot product +% for j, jx in enumerate(afix): + % if jx != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${j}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]); + % endif + % endif +% endfor +% endif + +## Iterate through the used rows of B % for kx in bix: bv = load_b(&b[i + ${kx}*ldb]); % for j, jx in enumerate(A[:, kx]): - % if jx != 0 and kx == afix[j]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[j]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## - % if kx == alix[j] and beta == 0: + % if preload and kx == alix[j]: + store_c(&c[i + ${j}*ldc], csub[${j}]); + % elif kx == alix[j] and beta == 0: store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]); diff --git a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako b/gimmik/kernels/hip/cstream-ksplit-preload-c.mako deleted file mode 100644 index b922666..0000000 --- a/gimmik/kernels/hip/cstream-ksplit-preload-c.mako +++ /dev/null @@ -1,108 +0,0 @@ -<%inherit file='base'/> - -<% -kparts = partition(A, ksplit, by='cols') -cchunks = chunk(range(m), csz) -loaded = set() -%> - -__global__ __launch_bounds__(${blockx*ksplit}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - int i = blockDim.x*blockIdx.x + threadIdx.x; - - ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp; - __shared__ ${dtype} csub[${ksplit - 1}][${csz}][${blockx}]; - -## Iterate over the row-partitions of C -% for cchunk in cchunks: - ## Iterate over the row-partitions of B - % for bid, kbx in enumerate(kparts): - if (i < n && threadIdx.y == ${bid}) - { - ## Evaluate our partial dot products - % for j in cchunk: - ## Load in any missing parts of B - % for kx in kbx: - % if A[j, kx] != 0 and kx not in loaded: - bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> - % endif - % endfor - <% - nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] - has_dotp = A[j].any() - if nzixs: - first_l_idx, first_kx = nzixs[0] - dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" - for l_idx, kx in nzixs[1:]: - dotex = f"{dotex} + {A[j, kx]}*bv[{l_idx}]" - else: - dotex = 'make_zero()' - %> - dotp = ${dotex}; - ## Save to a register - % if loop.index % ksplit == bid: - % if beta == 0: - cv[${loop.index // ksplit}] = dotp; - % elif beta == 1 and has_dotp: - cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] += dotp; - % elif has_dotp: - cv[${loop.index // ksplit}] = ${beta}*load_c(&c[i + ${j}*ldc]); - cv[${loop.index // ksplit}] += dotp; - % endif - ## Save to shared memory - % else: - csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; - % endif - % endfor - } - % endfor - __syncthreads(); - ## Iterate over the column-partitions of B - % for bid, kbx in enumerate(kparts): - if (i < n && threadIdx.y == ${bid}) - { - ## Sum and output the final set of dot products - % for j in cchunk: - % if loop.index % ksplit == bid: - <% has_dotp = A[j].any() %> - <% - sum_expr = f"cv[{loop.index // ksplit}]" - for s_idx in range(ksplit - 1): - sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" - %> - % if beta == 0: - dotp = ${sum_expr}; - store_c(&c[i + ${j}*ldc], dotp); - % elif beta == 1 and has_dotp: - dotp = ${sum_expr}; - store_c(&c[i + ${j}*ldc], dotp); - % elif beta != 1 and has_dotp: - dotp = ${sum_expr}; - store_c(&c[i + ${j}*ldc], dotp); - % elif beta != 1: - store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); - % endif - % endif - % endfor - } - % endfor - __syncthreads(); -% endfor -} diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index 5fc6711..12c59ba 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -4,6 +4,7 @@ kparts = partition(A, ksplit, by='cols') cchunks = chunk(range(m), csz) loaded = set() +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*ksplit}) void @@ -45,6 +46,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endfor <% nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] + has_dotp = A[j].any() if nzixs: first_l_idx, first_kx = nzixs[0] dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" @@ -56,7 +58,17 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) dotp = ${dotex}; ## Save to a register % if loop.index % ksplit == bid: + % if preload and beta == 0: cv[${loop.index // ksplit}] = dotp; + % elif preload and beta == 1 and has_dotp: + cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif preload and has_dotp: + cv[${loop.index // ksplit}] = ${beta}*load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif not preload: + cv[${loop.index // ksplit}] = dotp; + % endif ## Save to shared memory % else: csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; @@ -72,17 +84,31 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Sum and output the final set of dot products % for j in cchunk: % if loop.index % ksplit == bid: + <% has_dotp = A[j].any() %> <% sum_expr = f"cv[{loop.index // ksplit}]" for s_idx in range(ksplit - 1): sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" %> + % if preload and beta == 0: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta == 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % elif beta == 0: dotp = ${sum_expr}; - % if beta == 0: store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: + dotp = ${sum_expr}; store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: + dotp = ${sum_expr}; store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif diff --git a/gimmik/kernels/hip/cstream-preload-c.mako b/gimmik/kernels/hip/cstream-preload-c.mako deleted file mode 100644 index c9a83b0..0000000 --- a/gimmik/kernels/hip/cstream-preload-c.mako +++ /dev/null @@ -1,61 +0,0 @@ -<%inherit file='base'/> - -<% ksplit = 2 if m < 36 else 1 %> - -__global__ __launch_bounds__(${blockx}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ - % if width > 1: - n = (n + ${width} - 1) / ${width}; - ldb /= ${width}; - ldc /= ${width}; - % endif -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${-(-n // width)}; - const ${'long long' if k*ldb >= width*2**31 else 'int'} ldb = ${ldb // width}; - const ${'long long' if m*ldc >= width*2**31 else 'int'} ldc = ${ldc // width}; -% endif - const int i = blockDim.x*blockIdx.x + threadIdx.x; - ${dtype} dotp; - - if (i < n) - { -% for j, jx in enumerate(A): - <% - nzixs = [kx for kx, val in enumerate(jx) if val != 0] - if nzixs: - first_kx = nzixs[0] - dotex = f"{jx[first_kx]}*b[i + {first_kx}*ldb]" - for kx in nzixs[1:]: - dotex = f"{dotex} + {jx[kx]}*b[i + {kx}*ldb]" - else: - dotex = 'make_zero()' - %> - % if nzixs: - % if beta == 0: - dotp = ${dotex}; - store_c(&c[i + ${j}*ldc], dotp); - % elif beta == 1: - dotp = load_c(&c[i + ${j}*ldc]); - dotp += ${dotex}; - store_c(&c[i + ${j}*ldc], dotp); - % else: - dotp = ${beta}*load_c(&c[i + ${j}*ldc]); - dotp += ${dotex}; - store_c(&c[i + ${j}*ldc], dotp); - % endif - % else: - % if beta == 0: - store_c(&c[i + ${j}*ldc], make_zero()); - % elif beta != 1: - store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); - % endif - % endif -% endfor - } -} diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index 4ea995e..2ee9574 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -1,6 +1,8 @@ <%inherit file='base'/> -<% ksplit = 2 if m < 36 else 1 %> +<% +preload = context.get('preload', False) +%> __global__ __launch_bounds__(${blockx}) void % if n is None: @@ -37,10 +39,23 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) dotex = 'make_zero()' %> dotp = ${dotex}; - % if not nzixs: - dotp = make_zero(); - % endif - % if beta == 0: + % if preload and nzixs: + % if beta == 0: + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1: + dotp = load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % else: + dotp = ${beta}*load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % endif + % elif preload: + % if beta == 0: + store_c(&c[i + ${j}*ldc], make_zero()); + % elif beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % endif + % elif beta == 0: store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1 and nzixs: store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); From e014e4d40c1745e2b803e0f85ea3ce92da1aee50 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Wed, 24 Jun 2026 23:21:08 -0500 Subject: [PATCH 12/25] Avoid HIP vector operator+= overloads --- gimmik/kernels/hip/base.mako | 6 ------ 1 file changed, 6 deletions(-) diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 95e622f..d67ee25 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -5,9 +5,6 @@ inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) { return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); } -inline __device__ void operator+=(${dtype} &a, ${dtype} b) -{ a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; } - inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0, 0, 0); } % elif dtype.endswith('2'): @@ -17,9 +14,6 @@ inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) { return make_${dtype}(a*b.x, a*b.y); } -inline __device__ void operator+=(${dtype} &a, ${dtype} b) -{ a.x += b.x; a.y += b.y; } - inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0); } % else: From 9dfd0728e9ca903179f17e6026016b89a5df3306 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 14:51:18 +0800 Subject: [PATCH 13/25] Add f64 MFMA dense kernel for CDNA3 (gfx94x) Add f64 MFMA dense kernel for CDNA3 (gfx94x) Add a dense double-precision GEMM kernel using the CDNA Matrix Cores (__builtin_amdgcn_mfma_f64_16x16x4f64) as an alternative to rocBLAS on the dense path. A is densified, padded and baked into the kernel in fragment order, B is streamed from global memory, C is non-temporal stored, and the epilogue is fully unrolled. Each 64-lane wavefront sweeps 4 consecutive 16-wide N-tiles (64 cols per block), keeping the cols-per-block = blockx contract so the existing grid logic launches it unchanged. Gated on gfx94x and f64 operands with density >= 0.5 and m,k <= 128. Verified offline (py_compile, emission gating, fragment bake, and a numerical emulation matching A @ B to 1.8e-15); on-device accuracy still to be confirmed. --- gimmik/hip.py | 48 +++++++++++++++ gimmik/kernels/hip/mfma-dense.mako | 97 ++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 gimmik/kernels/hip/mfma-dense.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index 860ed6c..f1ca955 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -2,6 +2,8 @@ from gimmik.base import MatMul +import numpy as np + class HIPMatMul(MatMul): platform = 'hip' @@ -39,6 +41,20 @@ def emit(name, args, meta): meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} yield from emit('cstream-ksplit', args, meta) + # Dense f64 GEMM via the CDNA Matrix Cores (MFMA); see mfma-dense.mako. + # Modelled on the NVIDIA DMMA dense path: A is densified + baked in + # Matrix-Core fragment order, B is streamed, C is non-temporal stored. + # Densifying means it only pays off for reasonably dense operands, and + # the MFMA intrinsic is CDNA3-only (gfx94x). + if self._is_cdna3(gcn_arch) and self._mfma_dense_ok(dsize): + blkx = 64 + a_hex, m_tiles, k_tiles = self._mfma_dense_bake() + args = {'blockx': blkx, 'a_hex': a_hex, + 'm_tiles': m_tiles, 'k_tiles': k_tiles} + meta = {'block': (blkx, 1, 1), + 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-x{blkx}'} + yield from emit('mfma-dense', args, meta) + # Only emit tuned variants on architectures they have been validated for. base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: @@ -142,6 +158,38 @@ def emit(name, args, meta): } | wmeta yield from emit('cstream-ksplit-preload-c', args, meta) + @staticmethod + def _is_cdna3(gcn_arch): + base = gcn_arch.split(':', 1)[0] if gcn_arch else None + return base in {'gfx940', 'gfx941', 'gfx942'} + + def _mfma_dense_ok(self, dsize): + # f64 Matrix-Core only; the densified path is only worthwhile when A is + # reasonably dense, and is bounded so the baked A array stays small. + if dsize != 8 or self.m > 128 or self.k > 128: + return False + density = np.count_nonzero(self.A) / self.A.size + return density >= 0.5 + + def _mfma_dense_bake(self): + # Densify, pad and reorder A into v_mfma_f64_16x16x4 fragment order: + # Ag[(mt*k_tiles + kt)*64 + lane] + # = A_pad[mt*16 + lane%16][kt*4 + lane//16] + # i.e. with lane = g*16 + p, operand A wants i = p, kk = g. + m, k = self.A.shape + m_tiles = -(-m // 16) + k_tiles = -(-k // 4) + a_pad = np.zeros((m_tiles*16, k_tiles*4), dtype=np.float64) + a_pad[:m, :k] = self.A + a_hex = [] + for mt in range(m_tiles): + for kt in range(k_tiles): + for lane in range(64): + i = mt*16 + (lane % 16) + kk = kt*4 + (lane // 16) + a_hex.append(float(a_pad[i, kk]).hex()) + return a_hex, m_tiles, k_tiles + def _process_meta(self, meta): if self.n is not None: div = meta['block'][0]*meta['width'] diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako new file mode 100644 index 0000000..3f1b877 --- /dev/null +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -0,0 +1,97 @@ +<%inherit file='base'/> +## +## Dense double-precision GEMM using the CDNA Matrix Cores (MFMA). +## +## Modelled on the NVIDIA PTX "dmma-astream" kernel (mma.sync.aligned.m8n8k4): +## the constant operand A is densified, padded and *baked* into the kernel in +## Matrix-Core fragment order, B is streamed straight from global memory and C +## is written with non-temporal stores. The NVIDIA tensor-core tile is +## m8n8k4; the CDNA f64 Matrix-Core tile is m16n16k4, computed by +## __builtin_amdgcn_mfma_f64_16x16x4f64 over a single 64-lane wavefront. +## +## --------------------------------------------------------------------------- +## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), with +## g = lane / 16 (0..3) p = lane % 16 (0..15) +## A (16x4, 1 reg/lane): A[i][kk] with i = p, kk = g +## B (4x16, 1 reg/lane): B[kk][j] with kk = g, j = p +## C/D(16x16, 4 reg/lane): D[i][j] with j = p, i = 4*g + reg +## The baked A array (built in hip.py) uses EXACTLY this mapping: +## Ag[(mt*k_tiles + kt)*64 + lane] = A_padded[mt*16 + (lane%16)][kt*4 + (lane//16)] +## If an on-device accuracy check fails, this single mapping (here + in +## _mfma_dense_bake) is the place to revisit -- A, B and C are all derived +## from it consistently. +## --------------------------------------------------------------------------- +## +<% + tiles = blockx // 16 # 16-wide N-tiles swept per wavefront (= cols/block / 16) +%> +typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); + +__device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { + ${', '.join(a_hex)} +}; + +__global__ __launch_bounds__(${blockx}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${n}; + const ${'long long' if k * ldb >= 2**31 else 'int'} ldb = ${ldb}; + const ${'long long' if m * ldc >= 2**31 else 'int'} ldc = ${ldc}; +% endif + const int lane = threadIdx.x; + const int g = lane / 16; // 0..3 + const int p = lane % 16; // 0..15 + const int col_base = ${blockx}*blockIdx.x; + + ${dtype} a; +% for mt in range(m_tiles): +% for t in range(tiles): + gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; +% endfor +% endfor + +% for kt in range(k_tiles): +<% + krow_guard = (kt + 1)*4 > k +%> +% for t in range(tiles): + ${dtype} bv_${t}; + { + const int col = col_base + ${t*16} + p; + const int krow = ${kt*4} + g; + bv_${t} = (col < n${' && krow < %d' % k if krow_guard else ''}) ? b[krow*ldb + col] : (${dtype})0; + } +% endfor +% for mt in range(m_tiles): + a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; +% for t in range(tiles): + acc_${mt}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${mt}_${t}, 0, 0, 0); +% endfor +% endfor +% endfor + +% for mt in range(m_tiles): +% for t in range(tiles): +% for reg in range(4): + { + const int row = ${mt*16 + reg} + 4*g; + const int col = col_base + ${t*16} + p; + if (row < ${m} && col < n) +% if beta == 0: + store_c(&c[row*ldc + col], acc_${mt}_${t}[${reg}]); +% elif beta == 1: + store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${mt}_${t}[${reg}])); +% else: + store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${mt}_${t}[${reg}])); +% endif + } +% endfor +% endfor +% endfor +} From 6d237ef66603999afa213fad2e8aced8d2d825b5 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 14:54:23 +0800 Subject: [PATCH 14/25] Update mfma-dense.mako --- gimmik/kernels/hip/mfma-dense.mako | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index 3f1b877..a0bf6fb 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -50,6 +50,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) const int col_base = ${blockx}*blockIdx.x; ${dtype} a; +% for t in range(tiles): + ${dtype} bv_${t}; +% endfor % for mt in range(m_tiles): % for t in range(tiles): gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; @@ -61,7 +64,6 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) krow_guard = (kt + 1)*4 > k %> % for t in range(tiles): - ${dtype} bv_${t}; { const int col = col_base + ${t*16} + p; const int krow = ${kt*4} + g; From f6bc30882c7fe8c36da58d4763a2b32089b13a49 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Thu, 25 Jun 2026 02:02:37 -0500 Subject: [PATCH 15/25] Prune HIP tuned variants to 12 Reduce the HIP tuned kernel search space from 28 variants to 12 and order the remaining variants to try common winners earlier. --- gimmik/hip.py | 54 ++------------------------------------------------- 1 file changed, 2 insertions(+), 52 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 1fceb01..b55e3fe 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -24,35 +24,17 @@ def emit_preload(name, args, meta): blkx = self.basemeta['block'][0] - # B loading, C streaming kernel - yield from emit('cstream', {'blockx': blkx}, {}) - - # B streaming, C accumulation kernel - yield from emit('bstream', {'blockx': blkx}, {}) - - # Four-way m-split B streaming, C accumulation kernel - ms, bsz, blkx = 4, 24, 64 - args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield from emit('bstream-msplit', args, meta) - - # Two-way k-split B loading, C streaming kernel - ks, csz, blkx = 2, 24, 64 - args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} - yield from emit('cstream-ksplit', args, meta) - # Only emit tuned variants on architectures they have been validated for. base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: return # Tuned HIP variants - msplits, ksplits = [4, 8], [2, 4] + msplits, ksplits = [8, 4], [4, 2] bsz, csz, blkx = 8, 8, 64 widths = [1] if self.aligne is not None and self.aligne % 2 == 0: - widths.append(2) + widths.insert(0, 2) for width in widths: wargs = ({'dtype': f'{dtype}{width}', 'width': width} @@ -60,17 +42,6 @@ def emit_preload(name, args, meta): wmeta = {'width': width} if width > 1 else {} wpfx = f'w{width}-' if width > 1 else '' - # B loading, C streaming kernel - args = {'blockx': blkx} | wargs - meta = {'block': (blkx, 1, 1), - 'desc': f'cstream/{wpfx}x{blkx}'} | wmeta - yield from emit('cstream', args, meta) - - # B streaming, C accumulation kernel - meta = {'block': (blkx, 1, 1), - 'desc': f'bstream/{wpfx}x{blkx}'} | wmeta - yield from emit('bstream', args, meta) - for ms in msplits: # m-split B streaming, C accumulation kernel args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs @@ -81,27 +52,6 @@ def emit_preload(name, args, meta): } | wmeta yield from emit('bstream-msplit', args, meta) - for ks in ksplits: - # k-split B loading, C streaming kernel - args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs - shared = (ks - 1)*csz*blkx*dsize*width - meta = { - 'block': (blkx, ks, 1), 'shared': shared, - 'desc': f'cstream-ksplit/{wpfx}k{ks}-c{csz}-x{blkx}' - } | wmeta - yield from emit('cstream-ksplit', args, meta) - - # B loading, C preloading, C streaming kernel - args = {'blockx': blkx} | wargs - meta = {'block': (blkx, 1, 1), - 'desc': f'cstream-preload-c/{wpfx}x{blkx}'} | wmeta - yield from emit_preload('cstream', args, meta) - - # B streaming, C preloading, C accumulation kernel - meta = {'block': (blkx, 1, 1), - 'desc': f'bstream-preload-c/{wpfx}x{blkx}'} | wmeta - yield from emit_preload('bstream', args, meta) - for ms in msplits: # m-split B streaming, C preloading, C accumulation kernel args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs From 6689a9c662fe17a3ba812f8bbd4d70eb9b57a5e0 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 15:03:09 +0800 Subject: [PATCH 16/25] Update mfma-dense.mako --- gimmik/kernels/hip/mfma-dense.mako | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index a0bf6fb..f0466d4 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -14,7 +14,7 @@ ## g = lane / 16 (0..3) p = lane % 16 (0..15) ## A (16x4, 1 reg/lane): A[i][kk] with i = p, kk = g ## B (4x16, 1 reg/lane): B[kk][j] with kk = g, j = p -## C/D(16x16, 4 reg/lane): D[i][j] with j = p, i = 4*g + reg +## C/D(16x16, 4 reg/lane): D[i][j] with j = p, i = 4*reg + g ## The baked A array (built in hip.py) uses EXACTLY this mapping: ## Ag[(mt*k_tiles + kt)*64 + lane] = A_padded[mt*16 + (lane%16)][kt*4 + (lane//16)] ## If an on-device accuracy check fails, this single mapping (here + in @@ -82,7 +82,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for t in range(tiles): % for reg in range(4): { - const int row = ${mt*16 + reg} + 4*g; + const int row = ${mt*16 + 4*reg} + g; const int col = col_base + ${t*16} + p; if (row < ${m} && col < n) % if beta == 0: From a3aee45280334d46dbfb578a03a90494f64f2041 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Thu, 25 Jun 2026 02:19:05 -0500 Subject: [PATCH 17/25] Remove HIP variant arch gate --- gimmik/hip.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index b55e3fe..9365938 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -24,11 +24,6 @@ def emit_preload(name, args, meta): blkx = self.basemeta['block'][0] - # Only emit tuned variants on architectures they have been validated for. - base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None - if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: - return - # Tuned HIP variants msplits, ksplits = [8, 4], [4, 2] bsz, csz, blkx = 8, 8, 64 From b521427c3535fbf5000c77b1476a52673a4d9c46 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 15:34:00 +0800 Subject: [PATCH 18/25] Update hip.py --- gimmik/hip.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index f1ca955..3e09a00 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -164,12 +164,14 @@ def _is_cdna3(gcn_arch): return base in {'gfx940', 'gfx941', 'gfx942'} def _mfma_dense_ok(self, dsize): - # f64 Matrix-Core only; the densified path is only worthwhile when A is - # reasonably dense, and is bounded so the baked A array stays small. - if dsize != 8 or self.m > 128 or self.k > 128: - return False - density = np.count_nonzero(self.A) / self.A.size - return density >= 0.5 + # f64 Matrix Cores only (that is the only hard requirement of the + # mfma_f64_16x16x4 instruction). The kernel densifies A and is left + # for the autotuner to accept or reject on speed; the earlier + # m,k <= 128 and density >= 0.5 gates were too strict and hid it from + # real PyFR tet operators. Large m increases register pressure (each + # wavefront keeps m_tiles*4 v4f64 accumulators live) -> m-splitting is + # the natural follow-up if that becomes the bottleneck. + return dsize == 8 def _mfma_dense_bake(self): # Densify, pad and reorder A into v_mfma_f64_16x16x4 fragment order: From 3390912bc9048377540cf1dd68448b7ce567933f Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 16:14:17 +0800 Subject: [PATCH 19/25] Add m-splitting and zero-tile skipping to MFMA dense kernel The dense MFMA kernel held all m_tiles*4 v4f64 accumulators live in a single wavefront, which on larger operators meant very high register pressure (~350 VGPR), low occupancy, and poor latency hiding on these bandwidth-bound problems. It also computed every tile of the densified A, including all-zero blocks. Two knobs address both. m-splitting: msplit wavefronts per block, placed in block.y so block.x stays 64 (one wavefront = the cols-per-block grid contract is unchanged). Each wavefront now owns ceil(m_tiles/msplit) m-tiles, so it keeps only ceil(m_tiles/msplit)*4 accumulators live. For msplit > 1 the B tile is staged once into LDS with a synchronous cooperative copy and shared across the block, so B is not re-read per wavefront. Per-warp work is emitted as compile-time blocks (if threadIdx.y == w) so accumulators are scope-local and registers are reused across branches. msplit = 1 keeps the original direct, no-LDS path. Zero-tile skipping: _mfma_dense_bake now also returns amask, marking 16x4 A-tiles that contain a non-zero. All-zero tiles skip their MMA, and on the direct path a fully-zero k-tile also skips its B load. hip.py emits mfma-dense for msplit in {1, 2, 4} (capped at m_tiles), with shared = k_pad*64*dsize for the LDS variants. The accumulator/store path uses the accumulator for every m-tile (inactive ones stay zero), so beta is handled uniformly. Verified offline: the generator emits the expected s1/s2/s4 variants with correct block/shared/grid, and a numerical emulation of both paths reproduces A @ B (with beta) to <= 2.5e-14 for dense, structured-sparse, large-m, and msplit 1/2/4 cases. On-device accuracy and performance still to be benchmarked. --- gimmik/hip.py | 32 ++++++-- gimmik/kernels/hip/mfma-dense.mako | 124 +++++++++++++++++++++-------- 2 files changed, 118 insertions(+), 38 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 3e09a00..64422f0 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -48,12 +48,17 @@ def emit(name, args, meta): # the MFMA intrinsic is CDNA3-only (gfx94x). if self._is_cdna3(gcn_arch) and self._mfma_dense_ok(dsize): blkx = 64 - a_hex, m_tiles, k_tiles = self._mfma_dense_bake() - args = {'blockx': blkx, 'a_hex': a_hex, - 'm_tiles': m_tiles, 'k_tiles': k_tiles} - meta = {'block': (blkx, 1, 1), - 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-x{blkx}'} - yield from emit('mfma-dense', args, meta) + a_hex, m_tiles, k_tiles, amask = self._mfma_dense_bake() + k_pad = k_tiles*4 + for ms in self._mfma_msplits(m_tiles): + # msplit goes in block.y (cf. bstream-msplit) so block.x stays + # 64 = one wavefront = the cols-per-block grid contract. + shared = k_pad*blkx*dsize if ms > 1 else 0 + args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, + 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms} + meta = {'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-s{ms}-x{blkx}'} + yield from emit('mfma-dense', args, meta) # Only emit tuned variants on architectures they have been validated for. base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None @@ -173,11 +178,22 @@ def _mfma_dense_ok(self, dsize): # the natural follow-up if that becomes the bottleneck. return dsize == 8 + def _mfma_msplits(self, m_tiles): + # m-split factors to offer (placed in block.y). Each wavefront keeps + # m_tiles/msplit * 4 v4f64 accumulators live, so splitting m lowers + # register pressure / raises occupancy on large-m operators. msplit=1 + # is the direct (no-LDS) path; msplit>1 stages B once in LDS and shares + # it across the block (so B is not re-read per wavefront). + return [ms for ms in (1, 2, 4) if ms == 1 or ms <= m_tiles] + def _mfma_dense_bake(self): # Densify, pad and reorder A into v_mfma_f64_16x16x4 fragment order: # Ag[(mt*k_tiles + kt)*64 + lane] # = A_pad[mt*16 + lane%16][kt*4 + lane//16] # i.e. with lane = g*16 + p, operand A wants i = p, kk = g. + # amask[mt][kt] flags 16x4 A-tiles that contain a non-zero, so the + # kernel can skip the MMA (and, on the direct path, the B load) for + # all-zero tiles -- structural zero-tile skipping. m, k = self.A.shape m_tiles = -(-m // 16) k_tiles = -(-k // 4) @@ -190,7 +206,9 @@ def _mfma_dense_bake(self): i = mt*16 + (lane % 16) kk = kt*4 + (lane // 16) a_hex.append(float(a_pad[i, kk]).hex()) - return a_hex, m_tiles, k_tiles + amask = [[bool(np.any(a_pad[mt*16:mt*16+16, kt*4:kt*4+4])) + for kt in range(k_tiles)] for mt in range(m_tiles)] + return a_hex, m_tiles, k_tiles, amask def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index f0466d4..c51b75c 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -1,29 +1,31 @@ <%inherit file='base'/> ## -## Dense double-precision GEMM using the CDNA Matrix Cores (MFMA). +## Dense double-precision GEMM on the CDNA Matrix Cores (MFMA). ## -## Modelled on the NVIDIA PTX "dmma-astream" kernel (mma.sync.aligned.m8n8k4): -## the constant operand A is densified, padded and *baked* into the kernel in -## Matrix-Core fragment order, B is streamed straight from global memory and C -## is written with non-temporal stores. The NVIDIA tensor-core tile is -## m8n8k4; the CDNA f64 Matrix-Core tile is m16n16k4, computed by -## __builtin_amdgcn_mfma_f64_16x16x4f64 over a single 64-lane wavefront. +## Mirrors the NVIDIA PTX dense path: the constant operand A is densified, +## padded and baked into the kernel in Matrix-Core fragment order, B is +## streamed, and the epilogue is fully unrolled. Two knobs over v1: +## * zero-tile skipping -- amask[mt][kt] marks 16x4 A-tiles with a non-zero; +## all-zero tiles skip their MMA (and, on the direct path, the B load). +## * m-splitting -- msplit wavefronts per block (in block.y) each own a +## slice of the m-tiles, lowering per-wavefront accumulator pressure. +## For msplit>1 the B tile is staged once in LDS and shared by the whole +## block, so B is not re-read per wavefront. ## -## --------------------------------------------------------------------------- -## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), with -## g = lane / 16 (0..3) p = lane % 16 (0..15) -## A (16x4, 1 reg/lane): A[i][kk] with i = p, kk = g -## B (4x16, 1 reg/lane): B[kk][j] with kk = g, j = p -## C/D(16x16, 4 reg/lane): D[i][j] with j = p, i = 4*reg + g -## The baked A array (built in hip.py) uses EXACTLY this mapping: -## Ag[(mt*k_tiles + kt)*64 + lane] = A_padded[mt*16 + (lane%16)][kt*4 + (lane//16)] -## If an on-device accuracy check fails, this single mapping (here + in -## _mfma_dense_bake) is the place to revisit -- A, B and C are all derived -## from it consistently. -## --------------------------------------------------------------------------- +## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: +## A (16x4 ): A[i][kk] i=p, kk=g (1 reg/lane) +## B (4x16 ): B[kk][j] kk=g, j=p (1 reg/lane) +## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64, 4 reg/lane) +## Bake: Ag[(mt*k_tiles+kt)*64 + lane] = A_pad[mt*16 + lane%16][kt*4 + lane//16] ## <% - tiles = blockx // 16 # 16-wide N-tiles swept per wavefront (= cols/block / 16) + tiles = blockx // 16 # 16-wide N-tiles per wavefront + k_pad = k_tiles * 4 + active_kt = [kt for kt in range(k_tiles) + if any(amask[mt][kt] for mt in range(m_tiles))] + mtpg = -(-m_tiles // msplit) # m-tiles per wavefront + warp_mts = [[mt for mt in range(w*mtpg, min((w+1)*mtpg, m_tiles))] + for w in range(msplit)] %> typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); @@ -31,7 +33,7 @@ __device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { ${', '.join(a_hex)} }; -__global__ __launch_bounds__(${blockx}) void +__global__ __launch_bounds__(${blockx * msplit}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, @@ -45,10 +47,12 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) const ${'long long' if m * ldc >= 2**31 else 'int'} ldc = ${ldc}; % endif const int lane = threadIdx.x; - const int g = lane / 16; // 0..3 - const int p = lane % 16; // 0..15 + const int g = lane / 16; + const int p = lane % 16; const int col_base = ${blockx}*blockIdx.x; +% if msplit == 1: + ## ---- direct path: single wavefront, B straight from global ---- ${dtype} a; % for t in range(tiles): ${dtype} bv_${t}; @@ -58,11 +62,8 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; % endfor % endfor - -% for kt in range(k_tiles): -<% - krow_guard = (kt + 1)*4 > k -%> +% for kt in active_kt: +<% krow_guard = (kt + 1)*4 > k %> % for t in range(tiles): { const int col = col_base + ${t*16} + p; @@ -71,13 +72,14 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) } % endfor % for mt in range(m_tiles): +% if amask[mt][kt]: a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; -% for t in range(tiles): +% for t in range(tiles): acc_${mt}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${mt}_${t}, 0, 0, 0); -% endfor +% endfor +% endif % endfor % endfor - % for mt in range(m_tiles): % for t in range(tiles): % for reg in range(4): @@ -96,4 +98,64 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endfor % endfor % endfor + +% else: + ## ---- m-split path: stage B in LDS once, share across msplit wavefronts ---- + __shared__ ${dtype} ${kname}_Bs[${k_pad * blockx}]; + const int tid = threadIdx.y*${blockx} + threadIdx.x; + for (int idx = tid; idx < ${k_pad * blockx}; idx += ${blockx * msplit}) + { + const int krow = idx / ${blockx}; + const int cc = idx % ${blockx}; + const int col = col_base + cc; + ${kname}_Bs[idx] = (krow < ${k} && col < n) ? b[krow*ldb + col] : (${dtype})0; + } + __syncthreads(); + +% for w in range(msplit): +<% mts = warp_mts[w] %> +% if mts: + if (threadIdx.y == ${w}) + { + ${dtype} a, ${', '.join('bv_%d' % t for t in range(tiles))}; +% for j in range(len(mts)): +% for t in range(tiles): + gimmik_f64x4 acc_${j}_${t} = {0.0, 0.0, 0.0, 0.0}; +% endfor +% endfor +% for kt in active_kt: +% for t in range(tiles): + bv_${t} = ${kname}_Bs[(${kt*4} + g)*${blockx} + ${t*16} + p]; +% endfor +% for j, mt in enumerate(mts): +% if amask[mt][kt]: + a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; +% for t in range(tiles): + acc_${j}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${j}_${t}, 0, 0, 0); +% endfor +% endif +% endfor +% endfor +% for j, mt in enumerate(mts): +% for t in range(tiles): +% for reg in range(4): + { + const int row = ${mt*16 + 4*reg} + g; + const int col = col_base + ${t*16} + p; + if (row < ${m} && col < n) +% if beta == 0: + store_c(&c[row*ldc + col], acc_${j}_${t}[${reg}]); +% elif beta == 1: + store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${j}_${t}[${reg}])); +% else: + store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${j}_${t}[${reg}])); +% endif + } +% endfor +% endfor +% endfor + } +% endif +% endfor +% endif } From 99deb2e1ec61a3ef220ec92ca5e78bf2435a59cc Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 16:35:07 +0800 Subject: [PATCH 20/25] Add software-pipelined (double-buffered B) MFMA dense variant Add mfma-dense-pipe, a variant of the dense MFMA kernel's direct path that overlaps global B loads with Matrix-Core work. B for the next k-tile is loaded into a second register buffer before the current k-tile's MFMAs are issued, with the two buffers alternated per k-tile, so the global-load latency hides behind the MFMA pipeline. This targets the B-from-global, latency-bound case. The maths is identical to the msplit=1 direct path: A densified and baked in fragment order, B streamed, C non-temporal stored, epilogue fully unrolled, and zero-tile skipping via amask. It is emitted as a separate kernel so the prefetch restructuring does not complicate the main mfma-dense template. hip.py emits mfma-dense-pipe once per dense f64 operand on gfx94x, alongside the existing s1/s2/s4 mfma-dense variants. Verified offline: the generator emits the expected variant, and a numerical emulation of the pipelined path reproduces A @ B (with beta) to <= 1.4e-14 across dense, large-m, structured-sparse, single-k-tile, and beta cases. On-device performance still to be benchmarked. --- gimmik/hip.py | 8 ++ gimmik/kernels/hip/mfma-dense-pipe.mako | 112 ++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 gimmik/kernels/hip/mfma-dense-pipe.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index 64422f0..9baf1f8 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -60,6 +60,14 @@ def emit(name, args, meta): 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-s{ms}-x{blkx}'} yield from emit('mfma-dense', args, meta) + # Software-pipelined (double-buffered B) direct variant: prefetch + # next k-tile's B while the current k-tile's MFMAs run. + args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, + 'k_tiles': k_tiles, 'amask': amask} + meta = {'block': (blkx, 1, 1), 'shared': 0, + 'desc': f'mfma-dense-pipe/m{m_tiles}-k{k_tiles}-x{blkx}'} + yield from emit('mfma-dense-pipe', args, meta) + # Only emit tuned variants on architectures they have been validated for. base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: diff --git a/gimmik/kernels/hip/mfma-dense-pipe.mako b/gimmik/kernels/hip/mfma-dense-pipe.mako new file mode 100644 index 0000000..c35bea0 --- /dev/null +++ b/gimmik/kernels/hip/mfma-dense-pipe.mako @@ -0,0 +1,112 @@ +<%inherit file='base'/> +## +## Dense f64 GEMM on the CDNA Matrix Cores (MFMA) -- software-pipelined variant. +## +## Same maths as mfma-dense (msplit=1 direct path): A densified + baked in +## fragment order, B streamed from global, C non-temporal stored, epilogue +## fully unrolled, zero-tile skipping via amask. The only difference: B for +## the NEXT k-tile is issued before the MFMAs of the CURRENT k-tile, so the +## global-load latency overlaps the Matrix-Core work (double-buffered B in +## registers, buffers 0/1 alternated per k-tile). +## +## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: +## A (16x4 ): A[i][kk] i=p, kk=g +## B (4x16 ): B[kk][j] kk=g, j=p +## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64) +## +<% + tiles = blockx // 16 + active_kt = [kt for kt in range(k_tiles) + if any(amask[mt][kt] for mt in range(m_tiles))] +%> +typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); + +__device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { + ${', '.join(a_hex)} +}; + +__global__ __launch_bounds__(${blockx}) void +% if n is None: +${kname}(int n, + const ${dtype}* __restrict__ b, int ldb, + ${dtype}* __restrict__ c, int ldc) +{ +% else: +${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${n}; + const ${'long long' if k * ldb >= 2**31 else 'int'} ldb = ${ldb}; + const ${'long long' if m * ldc >= 2**31 else 'int'} ldc = ${ldc}; +% endif + const int lane = threadIdx.x; + const int g = lane / 16; + const int p = lane % 16; + const int col_base = ${blockx}*blockIdx.x; + + ${dtype} a; + ${dtype} ${', '.join('bv0_%d' % t for t in range(tiles))}; + ${dtype} ${', '.join('bv1_%d' % t for t in range(tiles))}; +% for mt in range(m_tiles): +% for t in range(tiles): + gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; +% endfor +% endfor + +% if active_kt: +<% kt0 = active_kt[0]; g0 = (kt0 + 1)*4 > k %> + // prefetch the first k-tile into buffer 0 +% for t in range(tiles): + { + const int col = col_base + ${t*16} + p; + const int krow = ${kt0*4} + g; + bv0_${t} = (col < n${' && krow < %d' % k if g0 else ''}) ? b[krow*ldb + col] : (${dtype})0; + } +% endfor + +% for i, kt in enumerate(active_kt): +<% + cur = i % 2 + nxt = (i + 1) % 2 + has_next = i + 1 < len(active_kt) +%> +% if has_next: +<% knext = active_kt[i+1]; gN = (knext + 1)*4 > k %> + // prefetch k-tile ${knext} into buffer ${nxt} +% for t in range(tiles): + { + const int col = col_base + ${t*16} + p; + const int krow = ${knext*4} + g; + bv${nxt}_${t} = (col < n${' && krow < %d' % k if gN else ''}) ? b[krow*ldb + col] : (${dtype})0; + } +% endfor +% endif +% for mt in range(m_tiles): +% if amask[mt][kt]: + a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; +% for t in range(tiles): + acc_${mt}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv${cur}_${t}, acc_${mt}_${t}, 0, 0, 0); +% endfor +% endif +% endfor +% endfor +% endif + +% for mt in range(m_tiles): +% for t in range(tiles): +% for reg in range(4): + { + const int row = ${mt*16 + 4*reg} + g; + const int col = col_base + ${t*16} + p; + if (row < ${m} && col < n) +% if beta == 0: + store_c(&c[row*ldc + col], acc_${mt}_${t}[${reg}]); +% elif beta == 1: + store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${mt}_${t}[${reg}])); +% else: + store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${mt}_${t}[${reg}])); +% endif + } +% endfor +% endfor +% endfor +} From 1e554de7d4cd85823c6a947c0bf309314491c04a Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 17:35:46 +0800 Subject: [PATCH 21/25] Cut B traffic in MFMA m-split path with bix-compacted, vectorized LDS load The m-split MFMA dense path staged the full padded B tile into LDS with a scalar cooperative copy, reading every k-row even when A uses only a subset. On these bandwidth-bound operators that wastes the resource that actually limits us. Two changes reduce and speed up the B read. bix compaction: only the k-rows A actually uses (sorted(self.bix)) are read from global into LDS. Hole rows and the padded tail are zeroed, because A is zero there and the MMA needs a finite (not NaN) operand -- an all-zero row left uninitialised would make 0*NaN corrupt the whole output row. The n-tail needs no such care: an MFMA output column depends only on its own B column, which is already store-guarded. The zero pass (and its extra __syncthreads) is emitted only when bix does not already cover every k-row read and k is a multiple of 4. Vectorized load: when the layout is 2-aligned (aligne % 2 == 0) the global read is issued as a 16-byte f64x2, with Bs declared __align__(16); otherwise it falls back to a scalar copy, with the odd-n column handled separately. hip.py passes bix_rows and the vec2 flag into the m-split variants; the direct (msplit=1) path is unchanged. Verified offline: a numerical emulation of the new staging reproduces A @ B (with beta) to <= 1.4e-14 for operands with zero k-columns (bix 17/24), fully dense (40/40), single active k-tile, msplit 2/4, and the scalar fallback. On-device performance still to be benchmarked; this only reduces the B-read side, so the gain is expected to be modest while C writes dominate. --- gimmik/hip.py | 5 +++- gimmik/kernels/hip/mfma-dense.mako | 42 +++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 9baf1f8..8673d37 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -50,12 +50,15 @@ def emit(name, args, meta): blkx = 64 a_hex, m_tiles, k_tiles, amask = self._mfma_dense_bake() k_pad = k_tiles*4 + bix_rows = sorted(self.bix) # k-rows A actually uses + vec2 = self.aligne is not None and self.aligne % 2 == 0 for ms in self._mfma_msplits(m_tiles): # msplit goes in block.y (cf. bstream-msplit) so block.x stays # 64 = one wavefront = the cols-per-block grid contract. shared = k_pad*blkx*dsize if ms > 1 else 0 args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, - 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms} + 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms, + 'bix_rows': bix_rows, 'vec2': vec2} meta = {'block': (blkx, ms, 1), 'shared': shared, 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-s{ms}-x{blkx}'} yield from emit('mfma-dense', args, meta) diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index c51b75c..94222d1 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -28,6 +28,7 @@ for w in range(msplit)] %> typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); +typedef ${dtype} gimmik_f64x2 __attribute__((ext_vector_type(2))); __device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { ${', '.join(a_hex)} @@ -101,15 +102,48 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % else: ## ---- m-split path: stage B in LDS once, share across msplit wavefronts ---- - __shared__ ${dtype} ${kname}_Bs[${k_pad * blockx}]; + ## Only the k-rows A actually uses (bix_rows) are read from global; holes + ## and the padded tail are zeroed (A is 0 there, so an MMA against 0 needs a + ## finite -- not NaN -- operand). The global read is vectorized as f64x2 + ## when the layout is 2-aligned. + __shared__ __align__(16) ${dtype} ${kname}_Bs[${k_pad * blockx}]; const int tid = threadIdx.y*${blockx} + threadIdx.x; - for (int idx = tid; idx < ${k_pad * blockx}; idx += ${blockx * msplit}) +<% + rows_read = sorted({kt*4 + r for kt in active_kt for r in range(4)}) + loaded = set(bix_rows) + need_zero = any(r not in loaded for r in rows_read) + nb = len(bix_rows) + nthreads = blockx * msplit + half = blockx // 2 +%> +% if need_zero: + for (int idx = tid; idx < ${k_pad * blockx}; idx += ${nthreads}) + ${kname}_Bs[idx] = (${dtype})0; + __syncthreads(); +% endif + static const int ${kname}_brows[${nb}] = { ${', '.join(map(str, bix_rows))} }; +% if vec2: + for (int idx = tid; idx < ${nb * half}; idx += ${nthreads}) { - const int krow = idx / ${blockx}; + const int krow = ${kname}_brows[idx / ${half}]; + const int cc = (idx % ${half}) * 2; + const int col = col_base + cc; + if (col + 1 < n) + *(gimmik_f64x2*)&${kname}_Bs[krow*${blockx} + cc] = + *(const gimmik_f64x2*)&b[krow*ldb + col]; + else if (col < n) + ${kname}_Bs[krow*${blockx} + cc] = b[krow*ldb + col]; + } +% else: + for (int idx = tid; idx < ${nb * blockx}; idx += ${nthreads}) + { + const int krow = ${kname}_brows[idx / ${blockx}]; const int cc = idx % ${blockx}; const int col = col_base + cc; - ${kname}_Bs[idx] = (krow < ${k} && col < n) ? b[krow*ldb + col] : (${dtype})0; + if (col < n) + ${kname}_Bs[krow*${blockx} + cc] = b[krow*ldb + col]; } +% endif __syncthreads(); % for w in range(msplit): From 7988c701306bc6e819241bd1e8fe4963b3097c5b Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 17:42:15 +0800 Subject: [PATCH 22/25] Compact MFMA m-split LDS tile to active k-tiles only The m-split path staged a full padded B tile in LDS, sized to every k-tile. Drop the inactive 4-wide k-slabs: the active k-tile at position a now occupies LDS rows [a*4, a*4+4), and the compute reads index it via compact_pos[kt]. Each tile keeps its four k-rows contiguous so the hardware g = lane/16 -> row mapping stays a plain offset (a full per-row compaction would need a runtime-g-indexed lookup, which risks the local-array-indexed-by-runtime register-allocation hazard). The cooperative load now uses two compile-time tables, bg[] (global k-row) and bl[] (compact LDS row), to scatter the bix rows into their tile slots; hole and padded rows inside an active tile are zeroed. LDS usage drops to n_akt*4*64*dsize, where n_akt is the number of active k-tiles (computed in hip.py for the shared-memory meta). Verified offline: a numerical emulation of the compacted staging reproduces A @ B (with beta) to <= 1.1e-14 for operands with zeroed k-slabs (k_tiles 6 -> 4 active, shared 12288 -> 8192), fully dense (unchanged), single active k-tile, and msplit 2/4. Smaller LDS means more resident blocks / higher occupancy, but only when A actually has all-zero k-slabs; k-dense operators see no change. On-device perf still to be benchmarked. --- gimmik/hip.py | 5 +++- gimmik/kernels/hip/mfma-dense.mako | 38 ++++++++++++++++++------------ 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 8673d37..c1a3112 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -52,10 +52,13 @@ def emit(name, args, meta): k_pad = k_tiles*4 bix_rows = sorted(self.bix) # k-rows A actually uses vec2 = self.aligne is not None and self.aligne % 2 == 0 + # active 4-wide k-tiles -> LDS holds only these (compacted) + n_akt = sum(any(amask[mt][kt] for mt in range(m_tiles)) + for kt in range(k_tiles)) for ms in self._mfma_msplits(m_tiles): # msplit goes in block.y (cf. bstream-msplit) so block.x stays # 64 = one wavefront = the cols-per-block grid contract. - shared = k_pad*blkx*dsize if ms > 1 else 0 + shared = n_akt*4*blkx*dsize if ms > 1 else 0 args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms, 'bix_rows': bix_rows, 'vec2': vec2} diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index 94222d1..b999b28 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -106,42 +106,50 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## and the padded tail are zeroed (A is 0 there, so an MMA against 0 needs a ## finite -- not NaN -- operand). The global read is vectorized as f64x2 ## when the layout is 2-aligned. - __shared__ __align__(16) ${dtype} ${kname}_Bs[${k_pad * blockx}]; - const int tid = threadIdx.y*${blockx} + threadIdx.x; + ## LDS stores ONLY the active k-tiles (inactive 4-wide k-slabs are dropped): + ## active kt at position a occupies LDS rows [a*4, a*4+4). Only bix rows are + ## read from global (into their tile slot); hole/pad rows are zeroed so the + ## MMA never multiplies A=0 by an uninitialised (possibly NaN) operand. <% - rows_read = sorted({kt*4 + r for kt in active_kt for r in range(4)}) - loaded = set(bix_rows) - need_zero = any(r not in loaded for r in rows_read) - nb = len(bix_rows) + compact_pos = {kt: a for a, kt in enumerate(active_kt)} + n_akt = len(active_kt) + # for each used k-row: (global row, compact LDS row) + bload = [(kr, compact_pos[kr // 4]*4 + kr % 4) + for kr in bix_rows if (kr // 4) in compact_pos] + nb = len(bload) + need_zero = nb < n_akt*4 # any hole/pad row inside an active tile nthreads = blockx * msplit half = blockx // 2 %> + __shared__ __align__(16) ${dtype} ${kname}_Bs[${n_akt * 4 * blockx}]; + const int tid = threadIdx.y*${blockx} + threadIdx.x; % if need_zero: - for (int idx = tid; idx < ${k_pad * blockx}; idx += ${nthreads}) + for (int idx = tid; idx < ${n_akt * 4 * blockx}; idx += ${nthreads}) ${kname}_Bs[idx] = (${dtype})0; __syncthreads(); % endif - static const int ${kname}_brows[${nb}] = { ${', '.join(map(str, bix_rows))} }; + static const int ${kname}_bg[${nb}] = { ${', '.join(str(g) for g, _ in bload)} }; + static const int ${kname}_bl[${nb}] = { ${', '.join(str(l) for _, l in bload)} }; % if vec2: for (int idx = tid; idx < ${nb * half}; idx += ${nthreads}) { - const int krow = ${kname}_brows[idx / ${half}]; + const int r = idx / ${half}; const int cc = (idx % ${half}) * 2; const int col = col_base + cc; if (col + 1 < n) - *(gimmik_f64x2*)&${kname}_Bs[krow*${blockx} + cc] = - *(const gimmik_f64x2*)&b[krow*ldb + col]; + *(gimmik_f64x2*)&${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = + *(const gimmik_f64x2*)&b[${kname}_bg[r]*ldb + col]; else if (col < n) - ${kname}_Bs[krow*${blockx} + cc] = b[krow*ldb + col]; + ${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = b[${kname}_bg[r]*ldb + col]; } % else: for (int idx = tid; idx < ${nb * blockx}; idx += ${nthreads}) { - const int krow = ${kname}_brows[idx / ${blockx}]; + const int r = idx / ${blockx}; const int cc = idx % ${blockx}; const int col = col_base + cc; if (col < n) - ${kname}_Bs[krow*${blockx} + cc] = b[krow*ldb + col]; + ${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = b[${kname}_bg[r]*ldb + col]; } % endif __syncthreads(); @@ -159,7 +167,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endfor % for kt in active_kt: % for t in range(tiles): - bv_${t} = ${kname}_Bs[(${kt*4} + g)*${blockx} + ${t*16} + p]; + bv_${t} = ${kname}_Bs[(${compact_pos[kt]*4} + g)*${blockx} + ${t*16} + p]; % endfor % for j, mt in enumerate(mts): % if amask[mt][kt]: From 1ee00bf0fcf29b06e9cd344818add3018ae6c311 Mon Sep 17 00:00:00 2001 From: "Eric.Chin.AMD" Date: Thu, 25 Jun 2026 17:57:08 +0800 Subject: [PATCH 23/25] k-block the MFMA m-split path so it fits (and wins) large-k operators On large-k operators (e.g. p5 tet m460: m~60, k~168) the m-split MFMA variants never ran: the single-chunk LDS B tile needed k_pad*4*64*8 = 84 KB, over the 64 KB limit, so only the scalar-B direct (s1) path was emitted -- leaving it ~4% behind a width-2 (double2) sparse kernel whose edge is exactly its vectorized B reads. Stage B into LDS in chunks of kc active k-tiles (kc = min(n_akt, 8), LDS fixed at 16 KB) and accumulate across the mako-unrolled chunk loop, so the m-split path fits any k and brings the double2 cooperative B load to large-k operators. To keep the cooperative (whole-block) load, cross-chunk accumulator persistence, and per-wavefront register scoping mutually compatible, the m-split path now uses a uniform m-tile assignment: wavefront owns m-tiles [wmt, wmt+mtpg) with wmt = threadIdx.y*mtpg evaluated at runtime, accumulators acc_j (j < mtpg) at function scope, and the real m-tile used at runtime for the Ag index and store row. The wmt+j < m_tiles guard is emitted only when m_tiles is not a multiple of msplit. This drops the compile-time zero-tile MMA skip on the m-split path (still present on the direct path); the bix B-load compaction is retained, so only the k-rows A uses are read. Verified offline: an emulation of the k-blocked path reproduces A @ B (with beta) to <= 5e-14 for multi-chunk large-k, single-chunk, zeroed-k-slab, non-divisible msplit, and beta cases; large-k operators now emit s1/s2/s4 with a bounded 16 KB shared allocation. On-device performance still to be benchmarked. --- gimmik/hip.py | 8 +- gimmik/kernels/hip/mfma-dense.mako | 186 ++++++++++++++--------------- 2 files changed, 97 insertions(+), 97 deletions(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index c1a3112..c2184d2 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -52,16 +52,18 @@ def emit(name, args, meta): k_pad = k_tiles*4 bix_rows = sorted(self.bix) # k-rows A actually uses vec2 = self.aligne is not None and self.aligne % 2 == 0 - # active 4-wide k-tiles -> LDS holds only these (compacted) + # active 4-wide k-tiles; the LDS m-split path stages them in + # chunks of kc tiles (k-blocking) so it fits any k. n_akt = sum(any(amask[mt][kt] for mt in range(m_tiles)) for kt in range(k_tiles)) + kc = min(n_akt, max(1, max_shared // (4*blkx*dsize) // 4)) or 1 for ms in self._mfma_msplits(m_tiles): # msplit goes in block.y (cf. bstream-msplit) so block.x stays # 64 = one wavefront = the cols-per-block grid contract. - shared = n_akt*4*blkx*dsize if ms > 1 else 0 + shared = kc*4*blkx*dsize if ms > 1 else 0 args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms, - 'bix_rows': bix_rows, 'vec2': vec2} + 'bix_rows': bix_rows, 'vec2': vec2, 'kc': kc} meta = {'block': (blkx, ms, 1), 'shared': shared, 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-s{ms}-x{blkx}'} yield from emit('mfma-dense', args, meta) diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense.mako index b999b28..d5a02b4 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense.mako @@ -2,30 +2,27 @@ ## ## Dense double-precision GEMM on the CDNA Matrix Cores (MFMA). ## -## Mirrors the NVIDIA PTX dense path: the constant operand A is densified, -## padded and baked into the kernel in Matrix-Core fragment order, B is -## streamed, and the epilogue is fully unrolled. Two knobs over v1: -## * zero-tile skipping -- amask[mt][kt] marks 16x4 A-tiles with a non-zero; -## all-zero tiles skip their MMA (and, on the direct path, the B load). -## * m-splitting -- msplit wavefronts per block (in block.y) each own a -## slice of the m-tiles, lowering per-wavefront accumulator pressure. -## For msplit>1 the B tile is staged once in LDS and shared by the whole -## block, so B is not re-read per wavefront. +## A is densified, padded and baked into the kernel in Matrix-Core fragment +## order; B is streamed; C is non-temporal stored; the epilogue is fully +## unrolled. Two code paths: +## msplit == 1 -- direct: one wavefront, B straight from global, compile-time +## zero-tile skipping (skip MMA + B load for all-zero tiles). +## msplit > 1 -- m-split + k-blocked: msplit wavefronts (block.y) each own a +## slice of the m-tiles; B is staged into LDS in chunks of kc +## active k-tiles (bounds LDS for any k) with a double2 +## cooperative copy, and only the k-rows A uses (bix) are read. ## ## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: -## A (16x4 ): A[i][kk] i=p, kk=g (1 reg/lane) -## B (4x16 ): B[kk][j] kk=g, j=p (1 reg/lane) -## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64, 4 reg/lane) +## A (16x4 ): A[i][kk] i=p, kk=g +## B (4x16 ): B[kk][j] kk=g, j=p +## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64) ## Bake: Ag[(mt*k_tiles+kt)*64 + lane] = A_pad[mt*16 + lane%16][kt*4 + lane//16] ## <% - tiles = blockx // 16 # 16-wide N-tiles per wavefront - k_pad = k_tiles * 4 + tiles = blockx // 16 active_kt = [kt for kt in range(k_tiles) if any(amask[mt][kt] for mt in range(m_tiles))] - mtpg = -(-m_tiles // msplit) # m-tiles per wavefront - warp_mts = [[mt for mt in range(w*mtpg, min((w+1)*mtpg, m_tiles))] - for w in range(msplit)] + mtpg = -(-m_tiles // msplit) %> typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); typedef ${dtype} gimmik_f64x2 __attribute__((ext_vector_type(2))); @@ -101,103 +98,104 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % endfor % else: - ## ---- m-split path: stage B in LDS once, share across msplit wavefronts ---- - ## Only the k-rows A actually uses (bix_rows) are read from global; holes - ## and the padded tail are zeroed (A is 0 there, so an MMA against 0 needs a - ## finite -- not NaN -- operand). The global read is vectorized as f64x2 - ## when the layout is 2-aligned. - ## LDS stores ONLY the active k-tiles (inactive 4-wide k-slabs are dropped): - ## active kt at position a occupies LDS rows [a*4, a*4+4). Only bix rows are - ## read from global (into their tile slot); hole/pad rows are zeroed so the - ## MMA never multiplies A=0 by an uninitialised (possibly NaN) operand. + ## ---- m-split + k-blocked path ---- <% - compact_pos = {kt: a for a, kt in enumerate(active_kt)} - n_akt = len(active_kt) - # for each used k-row: (global row, compact LDS row) - bload = [(kr, compact_pos[kr // 4]*4 + kr % 4) - for kr in bix_rows if (kr // 4) in compact_pos] - nb = len(bload) - need_zero = nb < n_akt*4 # any hole/pad row inside an active tile + chunks = [active_kt[c:c+kc] for c in range(0, len(active_kt), kc)] nthreads = blockx * msplit half = blockx // 2 + mt_guard = (m_tiles % msplit != 0) %> - __shared__ __align__(16) ${dtype} ${kname}_Bs[${n_akt * 4 * blockx}]; + __shared__ __align__(16) ${dtype} ${kname}_Bs[${kc * 4 * blockx}]; const int tid = threadIdx.y*${blockx} + threadIdx.x; -% if need_zero: - for (int idx = tid; idx < ${n_akt * 4 * blockx}; idx += ${nthreads}) + const int wmt = threadIdx.y*${mtpg}; + ${dtype} a, ${', '.join('bv_%d' % t for t in range(tiles))}; +% for j in range(mtpg): +% for t in range(tiles): + gimmik_f64x4 acc_${j}_${t} = {0.0, 0.0, 0.0, 0.0}; +% endfor +% endfor + +% for ci, chunk in enumerate(chunks): +<% + cpos = {kt: a for a, kt in enumerate(chunk)} + bload = [(kr, cpos[kr // 4]*4 + kr % 4) + for kr in bix_rows if (kr // 4) in cpos] + nb = len(bload) + need_zero = nb < len(chunk)*4 +%> +% if ci > 0: + __syncthreads(); +% endif +% if need_zero: + for (int idx = tid; idx < ${len(chunk)*4*blockx}; idx += ${nthreads}) ${kname}_Bs[idx] = (${dtype})0; __syncthreads(); -% endif - static const int ${kname}_bg[${nb}] = { ${', '.join(str(g) for g, _ in bload)} }; - static const int ${kname}_bl[${nb}] = { ${', '.join(str(l) for _, l in bload)} }; -% if vec2: - for (int idx = tid; idx < ${nb * half}; idx += ${nthreads}) - { - const int r = idx / ${half}; - const int cc = (idx % ${half}) * 2; - const int col = col_base + cc; - if (col + 1 < n) - *(gimmik_f64x2*)&${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = - *(const gimmik_f64x2*)&b[${kname}_bg[r]*ldb + col]; - else if (col < n) - ${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = b[${kname}_bg[r]*ldb + col]; - } -% else: - for (int idx = tid; idx < ${nb * blockx}; idx += ${nthreads}) +% endif { - const int r = idx / ${blockx}; - const int cc = idx % ${blockx}; - const int col = col_base + cc; - if (col < n) - ${kname}_Bs[${kname}_bl[r]*${blockx} + cc] = b[${kname}_bg[r]*ldb + col]; + static const int bg[${nb}] = { ${', '.join(str(x) for x, _ in bload)} }; + static const int bl[${nb}] = { ${', '.join(str(x) for _, x in bload)} }; +% if vec2: + for (int idx = tid; idx < ${nb * half}; idx += ${nthreads}) + { + const int r = idx / ${half}; + const int cc = (idx % ${half}) * 2; + const int col = col_base + cc; + if (col + 1 < n) + *(gimmik_f64x2*)&${kname}_Bs[bl[r]*${blockx} + cc] = + *(const gimmik_f64x2*)&b[bg[r]*ldb + col]; + else if (col < n) + ${kname}_Bs[bl[r]*${blockx} + cc] = b[bg[r]*ldb + col]; + } +% else: + for (int idx = tid; idx < ${nb * blockx}; idx += ${nthreads}) + { + const int r = idx / ${blockx}; + const int cc = idx % ${blockx}; + const int col = col_base + cc; + if (col < n) + ${kname}_Bs[bl[r]*${blockx} + cc] = b[bg[r]*ldb + col]; + } +% endif } -% endif __syncthreads(); - -% for w in range(msplit): -<% mts = warp_mts[w] %> -% if mts: - if (threadIdx.y == ${w}) - { - ${dtype} a, ${', '.join('bv_%d' % t for t in range(tiles))}; -% for j in range(len(mts)): -% for t in range(tiles): - gimmik_f64x4 acc_${j}_${t} = {0.0, 0.0, 0.0, 0.0}; -% endfor +% for a_pos, kt in enumerate(chunk): +% for t in range(tiles): + bv_${t} = ${kname}_Bs[(${a_pos*4} + g)*${blockx} + ${t*16} + p]; % endfor -% for kt in active_kt: +% for j in range(mtpg): +% if mt_guard: + if (wmt + ${j} < ${m_tiles}) +% endif + { + a = ${kname}_Ag[((wmt + ${j})*${k_tiles} + ${kt})*64 + lane]; % for t in range(tiles): - bv_${t} = ${kname}_Bs[(${compact_pos[kt]*4} + g)*${blockx} + ${t*16} + p]; -% endfor -% for j, mt in enumerate(mts): -% if amask[mt][kt]: - a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; -% for t in range(tiles): acc_${j}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${j}_${t}, 0, 0, 0); -% endfor -% endif % endfor + } % endfor -% for j, mt in enumerate(mts): -% for t in range(tiles): -% for reg in range(4): - { - const int row = ${mt*16 + 4*reg} + g; - const int col = col_base + ${t*16} + p; - if (row < ${m} && col < n) +% endfor +% endfor + +% for j in range(mtpg): +% for t in range(tiles): +% for reg in range(4): +% if mt_guard: + if (wmt + ${j} < ${m_tiles}) +% endif + { + const int row = (wmt + ${j})*16 + ${4*reg} + g; + const int col = col_base + ${t*16} + p; + if (row < ${m} && col < n) % if beta == 0: - store_c(&c[row*ldc + col], acc_${j}_${t}[${reg}]); + store_c(&c[row*ldc + col], acc_${j}_${t}[${reg}]); % elif beta == 1: - store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${j}_${t}[${reg}])); + store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${j}_${t}[${reg}])); % else: - store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${j}_${t}[${reg}])); + store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${j}_${t}[${reg}])); % endif - } -% endfor -% endfor -% endfor } -% endif +% endfor +% endfor % endfor % endif } From 2c7af9b3dbbe022d1cffe5e4845ead079d398320 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Thu, 25 Jun 2026 11:18:27 +0000 Subject: [PATCH 24/25] Restore MI355 HIP baseline variants --- gimmik/hip.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/gimmik/hip.py b/gimmik/hip.py index 9365938..142a799 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -22,7 +22,21 @@ def emit(name, args, meta): def emit_preload(name, args, meta): yield from emit(name, args | {'preload': True}, meta) - blkx = self.basemeta['block'][0] + ms, bsz, blkx = 4, 24, 64 + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = { + 'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}' + } + yield from emit('bstream-msplit', args, meta) + + ks, csz, blkx = 2, 24, 64 + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = { + 'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}' + } + yield from emit('cstream-ksplit', args, meta) # Tuned HIP variants msplits, ksplits = [8, 4], [4, 2] From ffc8aff25d0ec33f67d7da0a0bdc50ceb5a7de42 Mon Sep 17 00:00:00 2001 From: tomjen12 Date: Wed, 1 Jul 2026 03:41:43 +0000 Subject: [PATCH 25/25] Clean up HIP MFMA dense integration Keep the dense MFMA path focused on the m-split implementation, rename the template accordingly, and remove the unused direct pipelined variant. --- gimmik/hip.py | 82 ++++--------- ...mfma-dense.mako => mfma-dense-msplit.mako} | 68 ++--------- gimmik/kernels/hip/mfma-dense-pipe.mako | 112 ------------------ 3 files changed, 30 insertions(+), 232 deletions(-) rename gimmik/kernels/hip/{mfma-dense.mako => mfma-dense-msplit.mako} (65%) delete mode 100644 gimmik/kernels/hip/mfma-dense-pipe.mako diff --git a/gimmik/hip.py b/gimmik/hip.py index 83bd246..d11a08c 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -40,45 +40,32 @@ def emit_preload(name, args, meta): } yield from emit('cstream-ksplit', args, meta) - # Dense f64 GEMM via the CDNA Matrix Cores (MFMA); see mfma-dense.mako. - # Modelled on the NVIDIA DMMA dense path: A is densified + baked in - # Matrix-Core fragment order, B is streamed, C is non-temporal stored. - # Densifying means it only pays off for reasonably dense operands, and - # the MFMA intrinsic is CDNA3-only (gfx94x). - if self._is_cdna3(gcn_arch) and self._mfma_dense_ok(dsize): + if dsize == 8: blkx = 64 a_hex, m_tiles, k_tiles, amask = self._mfma_dense_bake() - k_pad = k_tiles*4 bix_rows = sorted(self.bix) # k-rows A actually uses - vec2 = self.aligne is not None and self.aligne % 2 == 0 - # active 4-wide k-tiles; the LDS m-split path stages them in - # chunks of kc tiles (k-blocking) so it fits any k. - n_akt = sum(any(amask[mt][kt] for mt in range(m_tiles)) - for kt in range(k_tiles)) - kc = min(n_akt, max(1, max_shared // (4*blkx*dsize) // 4)) or 1 - for ms in self._mfma_msplits(m_tiles): - # msplit goes in block.y (cf. bstream-msplit) so block.x stays - # 64 = one wavefront = the cols-per-block grid contract. - shared = kc*4*blkx*dsize if ms > 1 else 0 - args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, - 'k_tiles': k_tiles, 'amask': amask, 'msplit': ms, - 'bix_rows': bix_rows, 'vec2': vec2, 'kc': kc} - meta = {'block': (blkx, ms, 1), 'shared': shared, - 'desc': f'mfma-dense/m{m_tiles}-k{k_tiles}-s{ms}-x{blkx}'} - yield from emit('mfma-dense', args, meta) - - # Software-pipelined (double-buffered B) direct variant: prefetch - # next k-tile's B while the current k-tile's MFMAs run. - args = {'blockx': blkx, 'a_hex': a_hex, 'm_tiles': m_tiles, - 'k_tiles': k_tiles, 'amask': amask} - meta = {'block': (blkx, 1, 1), 'shared': 0, - 'desc': f'mfma-dense-pipe/m{m_tiles}-k{k_tiles}-x{blkx}'} - yield from emit('mfma-dense-pipe', args, meta) - - # Only emit tuned variants on architectures they have been validated for. - base_arch = gcn_arch.split(':', 1)[0] if gcn_arch else None - if base_arch not in {'gfx90a', 'gfx942'} or warp_size != 64: - return + vec2_opts = [(False, '')] + if self.aligne is not None and self.aligne % 2 == 0: + vec2_opts.insert(0, (True, 'w2-')) + + for vec2, wpfx in vec2_opts: + for kc in [8, 16]: + shared = kc*4*blkx*dsize + for ms in [8, 16]: + args = { + 'blockx': blkx, 'a_hex': a_hex, + 'm_tiles': m_tiles, 'k_tiles': k_tiles, + 'amask': amask, 'msplit': ms, + 'bix_rows': bix_rows, 'vec2': vec2, 'kc': kc + } + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': ( + f'mfma-dense-msplit/{wpfx}' + f'm{m_tiles}-k{k_tiles}-s{ms}-kc{kc}-x{blkx}' + ) + } + yield from emit('mfma-dense-msplit', args, meta) # Tuned HIP variants msplits, ksplits = [8, 4], [4, 2] @@ -129,29 +116,6 @@ def emit_preload(name, args, meta): } | wmeta yield from emit_preload('cstream-ksplit', args, meta) - @staticmethod - def _is_cdna3(gcn_arch): - base = gcn_arch.split(':', 1)[0] if gcn_arch else None - return base in {'gfx940', 'gfx941', 'gfx942'} - - def _mfma_dense_ok(self, dsize): - # f64 Matrix Cores only (that is the only hard requirement of the - # mfma_f64_16x16x4 instruction). The kernel densifies A and is left - # for the autotuner to accept or reject on speed; the earlier - # m,k <= 128 and density >= 0.5 gates were too strict and hid it from - # real PyFR tet operators. Large m increases register pressure (each - # wavefront keeps m_tiles*4 v4f64 accumulators live) -> m-splitting is - # the natural follow-up if that becomes the bottleneck. - return dsize == 8 - - def _mfma_msplits(self, m_tiles): - # m-split factors to offer (placed in block.y). Each wavefront keeps - # m_tiles/msplit * 4 v4f64 accumulators live, so splitting m lowers - # register pressure / raises occupancy on large-m operators. msplit=1 - # is the direct (no-LDS) path; msplit>1 stages B once in LDS and shares - # it across the block (so B is not re-read per wavefront). - return [ms for ms in (1, 2, 4) if ms == 1 or ms <= m_tiles] - def _mfma_dense_bake(self): # Densify, pad and reorder A into v_mfma_f64_16x16x4 fragment order: # Ag[(mt*k_tiles + kt)*64 + lane] diff --git a/gimmik/kernels/hip/mfma-dense.mako b/gimmik/kernels/hip/mfma-dense-msplit.mako similarity index 65% rename from gimmik/kernels/hip/mfma-dense.mako rename to gimmik/kernels/hip/mfma-dense-msplit.mako index d5a02b4..6a8971e 100644 --- a/gimmik/kernels/hip/mfma-dense.mako +++ b/gimmik/kernels/hip/mfma-dense-msplit.mako @@ -3,14 +3,11 @@ ## Dense double-precision GEMM on the CDNA Matrix Cores (MFMA). ## ## A is densified, padded and baked into the kernel in Matrix-Core fragment -## order; B is streamed; C is non-temporal stored; the epilogue is fully -## unrolled. Two code paths: -## msplit == 1 -- direct: one wavefront, B straight from global, compile-time -## zero-tile skipping (skip MMA + B load for all-zero tiles). -## msplit > 1 -- m-split + k-blocked: msplit wavefronts (block.y) each own a -## slice of the m-tiles; B is staged into LDS in chunks of kc -## active k-tiles (bounds LDS for any k) with a double2 -## cooperative copy, and only the k-rows A uses (bix) are read. +## order; B is staged through LDS; C is non-temporal stored; the epilogue is +## fully unrolled. This m-split + k-blocked path uses msplit wavefronts +## (block.y), each owning a slice of the m-tiles. B is staged into LDS in +## chunks of kc active k-tiles so LDS usage is bounded for any k, and only the +## k-rows A uses (bix) are read. ## ## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: ## A (16x4 ): A[i][kk] i=p, kk=g @@ -49,56 +46,6 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) const int p = lane % 16; const int col_base = ${blockx}*blockIdx.x; -% if msplit == 1: - ## ---- direct path: single wavefront, B straight from global ---- - ${dtype} a; -% for t in range(tiles): - ${dtype} bv_${t}; -% endfor -% for mt in range(m_tiles): -% for t in range(tiles): - gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; -% endfor -% endfor -% for kt in active_kt: -<% krow_guard = (kt + 1)*4 > k %> -% for t in range(tiles): - { - const int col = col_base + ${t*16} + p; - const int krow = ${kt*4} + g; - bv_${t} = (col < n${' && krow < %d' % k if krow_guard else ''}) ? b[krow*ldb + col] : (${dtype})0; - } -% endfor -% for mt in range(m_tiles): -% if amask[mt][kt]: - a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; -% for t in range(tiles): - acc_${mt}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv_${t}, acc_${mt}_${t}, 0, 0, 0); -% endfor -% endif -% endfor -% endfor -% for mt in range(m_tiles): -% for t in range(tiles): -% for reg in range(4): - { - const int row = ${mt*16 + 4*reg} + g; - const int col = col_base + ${t*16} + p; - if (row < ${m} && col < n) -% if beta == 0: - store_c(&c[row*ldc + col], acc_${mt}_${t}[${reg}]); -% elif beta == 1: - store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${mt}_${t}[${reg}])); -% else: - store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${mt}_${t}[${reg}])); -% endif - } -% endfor -% endfor -% endfor - -% else: - ## ---- m-split + k-blocked path ---- <% chunks = [active_kt[c:c+kc] for c in range(0, len(active_kt), kc)] nthreads = blockx * msplit @@ -189,13 +136,12 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if beta == 0: store_c(&c[row*ldc + col], acc_${j}_${t}[${reg}]); % elif beta == 1: - store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${j}_${t}[${reg}])); + store_c(&c[row*ldc + col], load_c(&c[row*ldc + col]) + acc_${j}_${t}[${reg}]); % else: - store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${j}_${t}[${reg}])); + store_c(&c[row*ldc + col], ${beta}*load_c(&c[row*ldc + col]) + acc_${j}_${t}[${reg}]); % endif } % endfor % endfor % endfor -% endif } diff --git a/gimmik/kernels/hip/mfma-dense-pipe.mako b/gimmik/kernels/hip/mfma-dense-pipe.mako deleted file mode 100644 index c35bea0..0000000 --- a/gimmik/kernels/hip/mfma-dense-pipe.mako +++ /dev/null @@ -1,112 +0,0 @@ -<%inherit file='base'/> -## -## Dense f64 GEMM on the CDNA Matrix Cores (MFMA) -- software-pipelined variant. -## -## Same maths as mfma-dense (msplit=1 direct path): A densified + baked in -## fragment order, B streamed from global, C non-temporal stored, epilogue -## fully unrolled, zero-tile skipping via amask. The only difference: B for -## the NEXT k-tile is issued before the MFMAs of the CURRENT k-tile, so the -## global-load latency overlaps the Matrix-Core work (double-buffered B in -## registers, buffers 0/1 alternated per k-tile). -## -## Operand lane layout for v_mfma_f64_16x16x4_f64 (wave64), g=lane/16, p=lane%16: -## A (16x4 ): A[i][kk] i=p, kk=g -## B (4x16 ): B[kk][j] kk=g, j=p -## C/D(16x16): D[i][j] j=p, i=4*reg + g (v4f64) -## -<% - tiles = blockx // 16 - active_kt = [kt for kt in range(k_tiles) - if any(amask[mt][kt] for mt in range(m_tiles))] -%> -typedef ${dtype} gimmik_f64x4 __attribute__((ext_vector_type(4))); - -__device__ static const ${dtype} ${kname}_Ag[${m_tiles * k_tiles * 64}] = { - ${', '.join(a_hex)} -}; - -__global__ __launch_bounds__(${blockx}) void -% if n is None: -${kname}(int n, - const ${dtype}* __restrict__ b, int ldb, - ${dtype}* __restrict__ c, int ldc) -{ -% else: -${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) -{ - const int n = ${n}; - const ${'long long' if k * ldb >= 2**31 else 'int'} ldb = ${ldb}; - const ${'long long' if m * ldc >= 2**31 else 'int'} ldc = ${ldc}; -% endif - const int lane = threadIdx.x; - const int g = lane / 16; - const int p = lane % 16; - const int col_base = ${blockx}*blockIdx.x; - - ${dtype} a; - ${dtype} ${', '.join('bv0_%d' % t for t in range(tiles))}; - ${dtype} ${', '.join('bv1_%d' % t for t in range(tiles))}; -% for mt in range(m_tiles): -% for t in range(tiles): - gimmik_f64x4 acc_${mt}_${t} = {0.0, 0.0, 0.0, 0.0}; -% endfor -% endfor - -% if active_kt: -<% kt0 = active_kt[0]; g0 = (kt0 + 1)*4 > k %> - // prefetch the first k-tile into buffer 0 -% for t in range(tiles): - { - const int col = col_base + ${t*16} + p; - const int krow = ${kt0*4} + g; - bv0_${t} = (col < n${' && krow < %d' % k if g0 else ''}) ? b[krow*ldb + col] : (${dtype})0; - } -% endfor - -% for i, kt in enumerate(active_kt): -<% - cur = i % 2 - nxt = (i + 1) % 2 - has_next = i + 1 < len(active_kt) -%> -% if has_next: -<% knext = active_kt[i+1]; gN = (knext + 1)*4 > k %> - // prefetch k-tile ${knext} into buffer ${nxt} -% for t in range(tiles): - { - const int col = col_base + ${t*16} + p; - const int krow = ${knext*4} + g; - bv${nxt}_${t} = (col < n${' && krow < %d' % k if gN else ''}) ? b[krow*ldb + col] : (${dtype})0; - } -% endfor -% endif -% for mt in range(m_tiles): -% if amask[mt][kt]: - a = ${kname}_Ag[${(mt*k_tiles + kt)*64} + lane]; -% for t in range(tiles): - acc_${mt}_${t} = __builtin_amdgcn_mfma_f64_16x16x4f64(a, bv${cur}_${t}, acc_${mt}_${t}, 0, 0, 0); -% endfor -% endif -% endfor -% endfor -% endif - -% for mt in range(m_tiles): -% for t in range(tiles): -% for reg in range(4): - { - const int row = ${mt*16 + 4*reg} + g; - const int col = col_base + ${t*16} + p; - if (row < ${m} && col < n) -% if beta == 0: - store_c(&c[row*ldc + col], acc_${mt}_${t}[${reg}]); -% elif beta == 1: - store_c(&c[row*ldc + col], gimmik_vadd(load_c(&c[row*ldc + col]), acc_${mt}_${t}[${reg}])); -% else: - store_c(&c[row*ldc + col], gimmik_vadd(gimmik_vmul(${beta}, load_c(&c[row*ldc + col])), acc_${mt}_${t}[${reg}])); -% endif - } -% endfor -% endfor -% endfor -}