Add Fused Multi-Head Attention example #16
AntonOresten wants to merge 10 commits into JuliaGPU:main
Conversation
Seeing some weird erroring when branching (being fixed in #53).

This works:

qk = if !EVEN_K[] && j >= mask_start
    offs_n = ((j - Int32(1)) * TILE_N[]) .+ offs_n_tile
    mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
    mask = mask .& (offs_n .<= k_seqlen)
    mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[]), Float32), ct.full((TILE_N[], TILE_M[]), -Inf32, Float32))
    qk .+ mask
else
    qk
end

but this doesn't:

if !EVEN_K[] && j >= mask_start
    offs_n = ((j - Int32(1)) * TILE_N[]) .+ offs_n_tile
    mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
    mask = mask .& (offs_n .<= k_seqlen)
    mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[]), Float32), ct.full((TILE_N[], TILE_M[]), -Inf32, Float32))
    qk = qk .+ mask
end

nor does this:

qk = if !EVEN_K[] && j >= mask_start
    offs_n = ((j - Int32(1)) * TILE_N[]) .+ offs_n_tile
    mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
    if !EVEN_K[]
        mask .& (offs_n .<= k_seqlen)
    end
    mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[]), Float32), ct.full((TILE_N[], TILE_M[]), -Inf32, Float32))
    qk .+ mask
else
    qk
end

In the second and third block, I get "ERROR: SSAValue %___ not found in context". After removing the second condition, I can suddenly have a nested if block, and I don't need the outer else block:

if !EVEN_K[]
    offs_n = ((j - Int32(1)) * TILE_N[]) .+ offs_n_tile
    mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
    if !EVEN_K[]
        mask = mask .& (offs_n .<= k_seqlen)
    end
    mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[]), Float32), ct.full((TILE_N[], TILE_M[]), -Inf32, Float32))
    qk = qk .+ mask
end

Does the if block need to depend on compile-time constants? I'd need this to construct the padding and causal masks properly.
That's an IRStructurizer error. Can you provide an MWE?
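A candidate distillation of the failing pattern in plain Julia (hypothetical and untested; assumes the error reproduces outside the full attention kernel):

# The value-producing `if` (first snippet above) compiles, while
# assigning to a variable defined outside the branch (second snippet)
# reportedly errors with "ERROR: SSAValue %___ not found in context".
function branch_mwe(qk::Float32, j::Int32, mask_start::Int32)
    # expression form: every path yields a value; works
    qk = if j >= mask_start
        qk + 1.0f0
    else
        qk
    end
    # statement form: mutates `qk` from the enclosing scope; reportedly fails
    if j >= mask_start
        qk = qk + 1.0f0
    end
    return qk
end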
Currently I need to wrap outside Float32 constants in Float32 within the kernel, because MulF otherwise fails (see the MethodError below):

qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)
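For context, a self-contained sketch of the workaround (the constant is 1/log(2); the helper name is illustrative, not the PR's actual kernel):

const INV_LOG_2 = 1.4426950408889634f0  # 1/log(2), already a Float32 literal

# Re-wrapping both operands in Float32 inside the kernel avoids the
# MulF failure on the outside constant.
scale_logits(qk_scale::Float32) = Float32(qk_scale) * Float32(INV_LOG_2)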
Another concern is whether I should convert to
Can you elaborate?
Yeah, that's a common Julia pain point. It's why we have
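Assuming the elided reference is to integer literal promotion (the reason for the Int32(1) wrapping in the snippets above), the pain point in a nutshell:

j = Int32(3)
typeof(j - 1)         # Int64 on 64-bit systems: the literal 1 is an Int
typeof(j - Int32(1))  # Int32: explicit wrapping keeps the narrow type
# CUDA.jl ships an `i32` literal suffix (`j - 1i32`) for exactly this,
# which may be what is being referenced here.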
I define

ERROR: LoadError: MethodError: no method matching encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::Nothing)
The function `encode_MulFOp!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
  encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::cuTile.Value; rounding_mode, flush_to_zero)
   @ cuTile ~/.julia/dev/cuTile/src/bytecode/encodings.jl:720
Oh, neat. I didn't know. I considered maybe a
In general, Julia's array indexing requires
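Presumably the requirement is Int indices; standard Julia normalizes other integer types on the way in:

A = rand(Float32, 4)
i = Int32(2)
A[i] == A[Int(i)]           # true: indices are converted to Int internally
Base.to_indices(A, (i,))    # (2,), a plain Int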
Constants should work without the type conversion now.
Seems to fall slightly short of my NNop / ONIONop baseline (no WMMA), although I haven't compared it to the Python version. On my GPU, it compiles and runs fastest with tile_n=32 and tile_m=32.

EDIT: this is without tensor cores. Simply switching the compute type to TFloat32 / BFloat16 and exploring the optimization and entry hint landscape makes forward and backward passes ~10x faster.
Notably, cutile-python has a latency argument for ct.load, as well as num_ctas and occupancy arguments for the kernel, which might affect performance. (EDIT: fixed in #32 and #27.) The Python version also does a kernel config autotune by searching a space of hand-picked configurations.

Another thing that might be important for correctness or for covering edge cases: exposing flush_to_zero? It's used in e.g. exp2. Haven't thought about in which cases this matters.
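For reference, a minimal hand-rolled autotune loop in Julia, in the spirit of the Python version's config search (run_attention and the config list are hypothetical placeholders, not this PR's API):

function autotune(q, k, v; configs = [(32, 32), (64, 32), (64, 64), (128, 64)])
    best_time, best_config = Inf, first(configs)
    for (tile_m, tile_n) in configs
        run_attention(q, k, v; tile_m, tile_n)            # first call compiles
        t = @elapsed run_attention(q, k, v; tile_m, tile_n)
        if t < best_time
            best_time, best_config = t, (tile_m, tile_n)
        end
    end
    return best_config, best_time
end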