Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8e23d63
Add tuned HIP GiMMiK preload variants
tomjen12 Jun 18, 2026
8f4d03e
Fix HIP GiMMiK block size metadata
tomjen12 Jun 18, 2026
96671a6
Address HIP GiMMiK review comments
tomjen12 Jun 22, 2026
739a82e
Handle ROCm feature suffixes for gfx942 tuning
tomjen12 Jun 22, 2026
0633539
Enable tuned HIP variants on gfx90a
tomjen12 Jun 22, 2026
7b59fb0
Parameterize HIP vector width and refine preload kernels
tomjen12 Jun 23, 2026
e9b921a
Use blockx launch bounds for HIP cstream preload
tomjen12 Jun 23, 2026
2aa2577
Always use non-temporal C accesses for HIP
tomjen12 Jun 24, 2026
be1c1db
feat(hip): add non-temporal B-load (NTB) variants for bstream-msplit
EricKing626 Jun 24, 2026
280e948
Use non-temporal B loads by default for HIP
tomjen12 Jun 25, 2026
c06216d
Make HIP preload-C a template option
tomjen12 Jun 25, 2026
e014e4d
Avoid HIP vector operator+= overloads
tomjen12 Jun 25, 2026
9dfd072
Add f64 MFMA dense kernel for CDNA3 (gfx94x)
EricKing626 Jun 25, 2026
6d237ef
Update mfma-dense.mako
EricKing626 Jun 25, 2026
f6bc308
Prune HIP tuned variants to 12
tomjen12 Jun 25, 2026
6689a9c
Update mfma-dense.mako
EricKing626 Jun 25, 2026
a3aee45
Remove HIP variant arch gate
tomjen12 Jun 25, 2026
b521427
Update hip.py
EricKing626 Jun 25, 2026
3390912
Add m-splitting and zero-tile skipping to MFMA dense kernel
EricKing626 Jun 25, 2026
99deb2e
Add software-pipelined (double-buffered B) MFMA dense variant
EricKing626 Jun 25, 2026
1e554de
Cut B traffic in MFMA m-split path with bix-compacted, vectorized LDS…
EricKing626 Jun 25, 2026
7988c70
Compact MFMA m-split LDS tile to active k-tiles only
EricKing626 Jun 25, 2026
1ee00bf
k-block the MFMA m-split path so it fits (and wins) large-k operators
EricKing626 Jun 25, 2026
2c7af9b
Restore MI355 HIP baseline variants
tomjen12 Jun 25, 2026
9e289a1
Merge branch 'hip-gimmik-mfma-dense' into hip-gimmik-mfma-dense-pr
tomjen12 Jul 1, 2026
ffc8aff
Clean up HIP MFMA dense integration
tomjen12 Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 124 additions & 10 deletions gimmik/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,143 @@

from gimmik.base import MatMul

import numpy as np


class HIPMatMul(MatMul):
platform = 'hip'
basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0}

# Generator of HIP kernel variants.  Every item produced is a
# (template name, template args dict, meta dict) triple.
# NOTE(review): this listing comes from a rendered diff whose +/- markers
# were stripped, so superseded and replacement statements can appear back
# to back (e.g. a one-line `meta = {...}` + `yield (...)` immediately
# followed by a multi-line `meta = {...}` + `yield from emit(...)`).
# Confirm against the repository which statements are live.
# NOTE(review): gcn_arch and warp_size are not referenced anywhere in the
# body shown here.
def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64):
# B loading, C streaming kernel
yield ('cstream', {}, {})
# Upper bounds applied by emit() below; max_shared is presumably in bytes
# since the shared sizes it is compared with are scaled by dsize -- TODO
# confirm.
max_block_threads = 1024
max_shared = 64*1024

# Yield the variant only when its block (threads = x*y*z) and shared
# request fit the two limits above; otherwise yield nothing.  Missing
# 'block'/'shared' keys fall back to self.basemeta.
def emit(name, args, meta):
block = meta.get('block', self.basemeta['block'])
shared = meta.get('shared', self.basemeta['shared'])
threads = block[0]*block[1]*block[2]

if threads <= max_block_threads and shared <= max_shared:
yield (name, args, meta)

# B streaming, C accumulation kernel
yield ('bstream', {}, {})
# Same filtering as emit(), with 'preload': True merged into the args
# handed to the template.
def emit_preload(name, args, meta):
yield from emit(name, args | {'preload': True}, meta)

# Four-way m-split B streaming, C accumulation kernel
ms, bsz, blkx = 4, 24, 64
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize}
yield ('bstream-msplit', args, meta)
meta = {
'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize,
'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'
}
yield from emit('bstream-msplit', args, meta)

# Two-way k-split B loading, C streaming kernel
ks, csz, blkx = 2, 24, 64
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize}
yield ('cstream-ksplit', args, meta)
meta = {
'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize,
'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'
}
yield from emit('cstream-ksplit', args, meta)

# MFMA dense variants are generated only when dsize == 8.  The baked A
# fragments, tile counts and tile mask are computed once and shared by
# every (vec2, kc, ms) combination below.
if dsize == 8:
blkx = 64
a_hex, m_tiles, k_tiles, amask = self._mfma_dense_bake()
bix_rows = sorted(self.bix) # k-rows A actually uses
# The vec2 variant (prefix 'w2-') is tried first, and only when
# self.aligne is set and even.
vec2_opts = [(False, '')]
if self.aligne is not None and self.aligne % 2 == 0:
vec2_opts.insert(0, (True, 'w2-'))

for vec2, wpfx in vec2_opts:
for kc in [8, 16]:
# The shared request depends on kc only, not on ms or vec2.
shared = kc*4*blkx*dsize
for ms in [8, 16]:
args = {
'blockx': blkx, 'a_hex': a_hex,
'm_tiles': m_tiles, 'k_tiles': k_tiles,
'amask': amask, 'msplit': ms,
'bix_rows': bix_rows, 'vec2': vec2, 'kc': kc
}
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': (
f'mfma-dense-msplit/{wpfx}'
f'm{m_tiles}-k{k_tiles}-s{ms}-kc{kc}-x{blkx}'
)
}
yield from emit('mfma-dense-msplit', args, meta)

# Tuned HIP variants
msplits, ksplits = [8, 4], [4, 2]
bsz, csz, blkx = 8, 8, 64
# Width 1 is always generated; width 2 is added in front when self.aligne
# is set and even.
widths = [1]
if self.aligne is not None and self.aligne % 2 == 0:
widths.insert(0, 2)

for width in widths:
# For width > 1 the template dtype becomes f'{dtype}{width}' and the
# width is recorded in both args and meta; width 1 adds nothing.
wargs = ({'dtype': f'{dtype}{width}', 'width': width}
if width > 1 else {})
wmeta = {'width': width} if width > 1 else {}
wpfx = f'w{width}-' if width > 1 else ''

for ms in msplits:
# m-split B streaming, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
# Shared request is scaled by the width.
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}'
} | wmeta
yield from emit('bstream-msplit', args, meta)

for ms in msplits:
# m-split B streaming, C preloading, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': (
f'bstream-msplit-preload-c/'
f'{wpfx}m{ms}-b{bsz}-x{blkx}'
)
} | wmeta
yield from emit_preload('bstream-msplit', args, meta)

for ks in ksplits:
# k-split B loading, C preloading, C streaming kernel
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs
shared = (ks - 1)*csz*blkx*dsize*width
meta = {
'block': (blkx, ks, 1), 'shared': shared,
'desc': (
f'cstream-ksplit-preload-c/'
f'{wpfx}k{ks}-c{csz}-x{blkx}'
)
} | wmeta
yield from emit_preload('cstream-ksplit', args, meta)

def _mfma_dense_bake(self):
    """Pack ``self.A`` into v_mfma_f64_16x16x4 operand-A fragment order.

    ``self.A`` is densified into a zero-padded float64 array whose row
    count is a multiple of 16 and whose column count is a multiple of 4,
    then flattened so that::

        a_hex[(mt*k_tiles + kt)*64 + lane]
            == A_pad[mt*16 + lane % 16][kt*4 + lane // 16]

    i.e. with ``lane = g*16 + p`` the fragment element is row ``p``,
    column ``g`` of the 16x4 tile ``(mt, kt)``.

    Returns:
        tuple: ``(a_hex, m_tiles, k_tiles, amask)``.  ``a_hex`` is a flat
        list of ``float.hex()`` strings in the order above;
        ``m_tiles``/``k_tiles`` are the tile counts along the rows and
        columns of A; ``amask[mt][kt]`` is ``True`` iff the 16x4 tile
        ``(mt, kt)`` holds at least one non-zero entry, so that all-zero
        tiles can be skipped structurally.
    """
    m, k = self.A.shape
    # Ceiling division: number of 16-row and 4-column tiles covering A
    m_tiles = -(-m // 16)
    k_tiles = -(-k // 4)

    a_pad = np.zeros((m_tiles*16, k_tiles*4), dtype=np.float64)
    a_pad[:m, :k] = self.A

    # View the padded matrix as tiles[mt, p, kt, g]
    #   == a_pad[mt*16 + p, kt*4 + g]
    tiles = a_pad.reshape(m_tiles, 16, k_tiles, 4)

    # Reordering the axes to (mt, kt, g, p) and flattening in C order
    # gives index (mt*k_tiles + kt)*64 + g*16 + p, which is exactly the
    # fragment layout documented above
    frags = tiles.transpose(0, 2, 3, 1).ravel()
    a_hex = [v.hex() for v in frags.tolist()]

    # A tile is active when any of its 16x4 entries is non-zero
    amask = tiles.any(axis=(1, 3)).tolist()

    return a_hex, m_tiles, k_tiles, amask

def _process_meta(self, meta):
if self.n is not None:
Expand Down
68 changes: 65 additions & 3 deletions gimmik/kernels/hip/base.mako
Original file line number Diff line number Diff line change
@@ -1,12 +1,74 @@
% if dtype.endswith('4'):
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b)
{ return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }

inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b)
{ return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); }

inline __device__ ${dtype} make_zero()
{ return make_${dtype}(0, 0, 0, 0); }
% elif dtype.endswith('2'):
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b)
{ return make_${dtype}(a.x + b.x, a.y + b.y); }

inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b)
{ return make_${dtype}(a*b.x, a*b.y); }

inline __device__ ${dtype} make_zero()
{ return make_${dtype}(0, 0); }
% else:
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} make_zero()
{ return 0; }
% endif

## Store `v` to `*p` through __builtin_nontemporal_store.  For dtype names
## ending in '4' or '2' each component (.x, .y and, for '4', .z, .w) is
## written with its own builtin call; otherwise the value is stored with a
## single call.
static inline __device__ void
nt_store(${dtype}* p, ${dtype} v)
{
% if dtype.endswith('4'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
__builtin_nontemporal_store(v.z, &p->z);
__builtin_nontemporal_store(v.w, &p->w);
% elif dtype.endswith('2'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
% else:
__builtin_nontemporal_store(v, p);
% endif
}

## Load and return `*p` through __builtin_nontemporal_load.  For dtype
## names ending in '4' or '2' each component is loaded with its own builtin
## call and the result is reassembled with make_<dtype>(); otherwise the
## value is loaded with a single call.
static inline __device__ ${dtype}
nt_load(const ${dtype}* p)
{
% if dtype.endswith('4'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y),
__builtin_nontemporal_load(&p->z),
__builtin_nontemporal_load(&p->w));
% elif dtype.endswith('2'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y));
% else:
return __builtin_nontemporal_load(p);
% endif
}

## Store hook: forwards unconditionally to nt_store().
## NOTE(review): going by the name this is the hook kernels use for writes
## to C -- confirm against the kernel templates.
static inline __device__ void
store_c(${dtype}* p, ${dtype} v)
{
nt_store(p, v);
}

## Load hook: forwards unconditionally to nt_load().
## NOTE(review): going by the name this is the hook kernels use for reads
## of C -- confirm against the kernel templates.
static inline __device__ ${dtype}
load_c(const ${dtype}* p)
{
return nt_load(p);
}

## Load hook: forwards unconditionally to nt_load().
## NOTE(review): going by the name this is the hook kernels use for reads
## of B -- confirm against the kernel templates.
static inline __device__ ${dtype}
load_b(const ${dtype}* p)
{
return nt_load(p);
}

${next.body()}
38 changes: 28 additions & 10 deletions gimmik/kernels/hip/bstream-msplit.mako
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<%
mx = partition(A, into=msplit, by='rows')
bchunks = chunk(bix, bsz)
preload = context.get('preload', False)
%>

__global__ __launch_bounds__(${blockx*msplit}) void
Expand All @@ -12,7 +13,7 @@ ${kname}(int n,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
Expand All @@ -34,9 +35,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
% for kx in bchunks[0]:
% if loop.index % msplit == cid:
bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]);
% endif
% endfor

% if preload and beta != 0:
## Preload C values for active rows owned by this m-split lane
% for j, jx in enumerate(mx[cid]):
% if afix[jx] != -1:
% if beta == 1:
csub[${j}] = load_c(&c[i + ${jx}*ldc]);
% else:
csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]);
% endif
% endif
% endfor
% endif
}
% endfor
__syncthreads();
Expand All @@ -51,36 +65,40 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
% if not loop.parent.last:
% for kx in bchunks[bb + 1]:
% if loop.index % msplit == cid:
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]);
% endif
% endfor
% endif
## Accumulate our dot products
% for kx in bchunks[bb]:
bv = bsub[${bb % 2}][${loop.index}][threadIdx.x];
% for j, jx in enumerate(A[mcx, kx]):
% if jx != 0 and kx == afix[mcx[j]]:
% if preload and beta != 0 and jx != 0:
csub[${j}] += ${jx}*bv;
% elif jx != 0 and kx == afix[mcx[j]]:
csub[${j}] = ${jx}*bv;
% elif jx != 0:
csub[${j}] += ${jx}*bv;
% endif
## If we're done with this dot product then store to global
% if kx == alix[mcx[j]] and beta == 0:
c[i + ${mcx[j]}*ldc] = csub[${j}];
% if preload and kx == alix[mcx[j]]:
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% elif kx == alix[mcx[j]] and beta == 0:
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% elif kx == alix[mcx[j]] and beta == 1:
c[i + ${mcx[j]}*ldc] += csub[${j}];
store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]);
% elif kx == alix[mcx[j]]:
c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${mcx[j]}*ldc]));
% endif
% endfor
% endfor
## Handle rows of A which are all zero
% if loop.parent.last:
% for j, jx in enumerate(afix):
% if jx == -1 and j % msplit == cid and beta == 0:
c[i + ${j}*ldc] = make_zero();
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and j % msplit == cid and beta != 1:
c[i + ${j}*ldc] *= ${beta};
store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
% endif
Expand Down
41 changes: 30 additions & 11 deletions gimmik/kernels/hip/bstream.mako
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
<%inherit file='base'/>

__global__ __launch_bounds__(128) void
<% preload = context.get('preload', False) %>

__global__ __launch_bounds__(${blockx}) void
% if n is None:
${kname}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
Expand All @@ -24,32 +26,49 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
${dtype} bv, csub[${m}];

## Iterare through the used rows of B
% if preload and beta != 0:
## Preload C values for rows which will receive a non-zero dot product
% for j, jx in enumerate(afix):
% if jx != -1:
% if beta == 1:
csub[${j}] = load_c(&c[i + ${j}*ldc]);
% else:
csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]);
% endif
% endif
% endfor
% endif

## Iterate through the used rows of B
% for kx in bix:
bv = b[i + ${kx}*ldb];
bv = load_b(&b[i + ${kx}*ldb]);
% for j, jx in enumerate(A[:, kx]):
% if jx != 0 and kx == afix[j]:
% if preload and beta != 0 and jx != 0:
csub[${j}] += ${jx}*bv;
% elif jx != 0 and kx == afix[j]:
csub[${j}] = ${jx}*bv;
% elif jx != 0:
csub[${j}] += ${jx}*bv;
% endif
##
% if kx == alix[j] and beta == 0:
c[i + ${j}*ldc] = csub[${j}];
% if preload and kx == alix[j]:
store_c(&c[i + ${j}*ldc], csub[${j}]);
% elif kx == alix[j] and beta == 0:
store_c(&c[i + ${j}*ldc], csub[${j}]);
% elif kx == alix[j] and beta == 1:
c[i + ${j}*ldc] += csub[${j}];
store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]);
% elif kx == alix[j]:
c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc];
store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
% endfor

## Handle rows of A which are all zero
% for j, jx in enumerate(afix):
% if jx == -1 and beta == 0:
c[i + ${j}*ldc] = make_zero();
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and beta != 1:
c[i + ${j}*ldc] *= ${beta};
store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
}
Expand Down
Loading