The `paged_attention_ll4mi_QKV_mfma16_kernel` has this structure:

- Each warp loads Q head elements into shared memory
  - Layout depends on `KV_DTYPE`: FP16/BF16 (`kAuto`) vs FP8 (different layout for the 32-wide MFMA)
  - turbo4 impact: None; Q is always in native dtype
- For each KV block, load K from the paged cache via 16-byte loads
  - `cache_t* k_ptr3 = k_ptr + kphysical_block_offset * kv_seq_stride`
  - 16B vectorized load: `Klocal[...] = *k_fetch_ptr_16B` (a `reinterpret_cast` to `_B16x8*`)
  - For FP8: `sizeof(cache_t) = 1`, so 16B = 16 elements; uses the `mfma16x16x32` FP8 instruction
  - For FP16: `sizeof(cache_t) = 2`, so 16B = 8 elements; uses the `mfma16x16x16` FP16 instruction
  - turbo4 impact: CRITICAL; need to load nibble-packed data and dequantize before the MFMA
- MFMA: `gcn_mfma16x16x32_instr<__hip_fp8_e4m3>` for FP8, `gcn_mfma16x16x16_instr<scalar_t>` for FP16/BF16
  - K scale applied to the softmax scale (not per-element)
  - turbo4 impact: After dequant, K is FP16 → use the FP16 MFMA path
- Online softmax with attention logits
  - turbo4 impact: None; same computation
- V: similar to K; load from the paged cache, then MFMA with the softmax weights
  - V scale applied post-MFMA for FP8
  - turbo4 impact: Same as K; need nibble dequant before MFMA
- Write to the output tensor
  - turbo4 impact: None
Current FP8 K load:

```cpp
const cache_t* k_fetch_ptr = k_ptr3 + offset1 * KX + offset2;
const _B16x8* k_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(k_fetch_ptr);
Klocal[head_loop][token_depth][qkhe_depth] = *k_fetch_ptr_16B;
```

Proposed turbo4 K load:
```cpp
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kTurbo4) {
  // Load nibble-packed bytes for this head element range.
  // Each thread processes CONTIGUOUS_KV_ELEMS_16B_LOAD dims (16 for FP8, 8 for FP16).
  // For turbo4: load 8 bytes (16 nibbles = 16 dims) per thread.
  const int dim_start = head_elem;  // starting dimension for this thread
  const int nibble_byte_start = dim_start / 2;
  const uint8_t* k_nibble_ptr =
      reinterpret_cast<const uint8_t*>(k_ptr3) + nibble_byte_start;
  // Load 8 bytes = 16 nibbles = 16 dims (assumes 8-byte alignment of the row).
  uint64_t packed = *reinterpret_cast<const uint64_t*>(k_nibble_ptr);
  // Load the per-position norm: 2 bytes at byte offset 64
  // (128 dims as nibbles = 64 bytes, norm immediately after).
  const uint8_t* norm_ptr = reinterpret_cast<const uint8_t*>(k_ptr3) + 64;
  half norm_h = *reinterpret_cast<const half*>(norm_ptr);
  float norm = __half2float(norm_h);
  // Extract 16 nibbles and look up centroids.
  half k_vals[16];
  for (int i = 0; i < 8; i++) {
    uint8_t byte = (packed >> (i * 8)) & 0xFF;
    int idx_lo = byte & 0xF;
    int idx_hi = (byte >> 4) & 0xF;
    k_vals[2 * i] = __float2half(centroids[idx_lo] * norm);
    k_vals[2 * i + 1] = __float2half(centroids[idx_hi] * norm);
  }
  // 16 dequantized halves are 32 bytes, i.e. two _B16x8 registers (one
  // _B16x8 holds 8 FP16 values), so they fill two Klocal slots; storing
  // into consecutive qkhe_depth slots is an assumption about the layout.
  const _B16x8* k_vals_16B = reinterpret_cast<const _B16x8*>(k_vals);
  Klocal[head_loop][token_depth][qkhe_depth] = k_vals_16B[0];
  Klocal[head_loop][token_depth][qkhe_depth + 1] = k_vals_16B[1];
} else {
  // Original path
  Klocal[...] = *k_fetch_ptr_16B;
}
```

The V fetch follows the same pattern as K: load the nibble bytes, extract, look up centroids, and multiply by the norm.
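Since K and V share that exact dequant sequence, one small device helper could serve both fetch loops. A minimal sketch, assuming the pointer and centroid-table arguments shown above; the helper name `turbo4_dequant16` is hypothetical, not in the AITER source:

```cpp
#include <hip/hip_fp16.h>
#include <cstdint>

// Hypothetical helper: expand 8 nibble-packed bytes (16 dims) into 16
// halves, scaled by the per-position norm. Assumes 8-byte alignment.
__device__ inline void turbo4_dequant16(const uint8_t* nibble_ptr,
                                        const float* centroids, float norm,
                                        half* out /*[16]*/) {
  uint64_t packed = *reinterpret_cast<const uint64_t*>(nibble_ptr);
#pragma unroll
  for (int i = 0; i < 8; i++) {
    uint8_t byte = (packed >> (i * 8)) & 0xFF;
    out[2 * i] = __float2half(centroids[byte & 0xF] * norm);             // low nibble = even dim
    out[2 * i + 1] = __float2half(centroids[(byte >> 4) & 0xF] * norm);  // high nibble = odd dim
  }
}
```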
In `pa_v1.py`, the dtype dispatch gains a turbo4 branch:

```python
elif kv_cache_dtype == "turbo4":
    if query.dtype == torch.bfloat16:
        dtype = "__hip_bfloat16"
        kv_dtype = "uint8_t"  # same as FP8 in terms of C++ type
    elif query.dtype == torch.float16:
        dtype = "_Float16"
        kv_dtype = "uint8_t"
    fp8_kv_dtype = "turbo4"  # new enum value
```

The new enum value in `pa_common.cuh`:

```cpp
enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
  kTurbo4 = 3,  // NEW: 4-bit nibble PolarQuant
};
```

The kernel gains a centroid-table argument (in `pa_v1.cuh` and `pa_v1.cpp.jinja`):

```cpp
template <...>
__inline__ __device__ void _paged_attention_kernel(
    ...,
    const float* __restrict__ centroids,  // NEW: [16] turbo4 centroids
    ...
)
```

FP8 uses `mfma16x16x32` (32 elements per instruction). After turbo4 dequant the data is FP16, so we'd use `mfma16x16x16` (16 elements per instruction). This means:
- Each MFMA consumes fewer K elements (16 vs 32 for FP8)
- But each K fetch loads fewer bytes (8 vs 16) since the data is nibble-packed
- Net: similar throughput
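One way to keep this dispatch explicit is a compile-time traits struct that the fetch and MFMA loops key off. A sketch of the idea only; `KvLoadTraits` is hypothetical and not in the AITER source:

```cpp
#include <cstdint>

enum class Fp8KVCacheDataType { kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2, kTurbo4 = 3 };

// Hypothetical traits summarizing the trade-off above: bytes fetched per
// thread-load, head dims produced per load, and which MFMA consumes them.
template <Fp8KVCacheDataType KV_DTYPE>
struct KvLoadTraits;

template <>
struct KvLoadTraits<Fp8KVCacheDataType::kFp8E4M3> {
  static constexpr int bytes_per_load = 16;  // 16 FP8 bytes
  static constexpr int dims_per_load = 16;   // 16 head dims
  static constexpr bool fp8_mfma = true;     // mfma16x16x32, operands stay FP8
};

template <>
struct KvLoadTraits<Fp8KVCacheDataType::kTurbo4> {
  static constexpr int bytes_per_load = 8;   // 8 nibble-packed bytes
  static constexpr int dims_per_load = 16;   // still 16 head dims after dequant
  static constexpr bool fp8_mfma = false;    // dequant to FP16, mfma16x16x16
};
```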
Each position has a per-position norm (2 bytes). This is an extra memory load per position that FP8 doesn't have. Options (the second is sketched below):
- Load the norm once per position and broadcast it to all dims (a register)
- Pre-multiply the centroids by the norm, giving 16 scaled centroids per position
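A minimal sketch of the pre-multiply option, assuming the centroids already sit in shared memory; the helper name `scale_centroids` is hypothetical:

```cpp
// Hypothetical helper: fold the per-position norm into a private centroid
// table once, so the inner nibble loop skips the per-element multiply
// (at the cost of 16 extra registers per thread).
__device__ inline void scale_centroids(const float* centroids_smem, float norm,
                                       float* scaled /*[16]*/) {
#pragma unroll
  for (int c = 0; c < 16; c++) {
    scaled[c] = centroids_smem[c] * norm;
  }
}
// The dequant loop then becomes: out[j] = __float2half(scaled[idx]);
```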
The centroid table is tiny: 16 centroids × 4 bytes = 64 bytes. Options:
- Constant memory: fastest, but limited (64KB total)
- Shared memory: fast, 64 bytes per block (negligible)
- Registers: each thread holds all 16 (64 bytes of registers per thread)

Recommendation: shared memory. Load once at kernel start; access is fast.
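A sketch of that staging step in the kernel prologue, assuming the `centroids` argument from above and a block of at least 16 threads:

```cpp
// Stage the 16-entry centroid table into shared memory once per block.
__shared__ float centroids_smem[16];
if (threadIdx.x < 16) {
  centroids_smem[threadIdx.x] = centroids[threadIdx.x];
}
__syncthreads();  // all warps see the table before the fetch loops begin
```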
| Component | FP8 (current) | turbo4 (proposed) |
|---|---|---|
| K bytes loaded | 128B/pos (16B × 8 threads) | 64B/pos (8B × 8 threads) + 2B norm |
| K dequant ALU | 0 (hardware FP8→FP32) | 16 shifts + 16 masks + 16 centroid lookups + 16 muls |
| K MFMA | mfma16x16x32 (FP8, 32-wide) | mfma16x16x16 (FP16, 16-wide) |
| V bytes loaded | Same as K | Same as K |
| Total data | 256B/pos | 132B/pos (48% less) |
| Extra ALU | 0 | ~128 ops/pos (shifts+masks+gathers+muls) |
Estimated turbo4 AITER: 70-100 μs/layer (1.5-2× FP8's 48.3 μs)

The 48% reduction in data movement should help, but the extra ALU work for nibble extraction and centroid lookup will partially offset the bandwidth savings. The key question is whether that ALU work can be hidden behind the memory pipeline.
Implementation steps:
- Add `kTurbo4` to the `Fp8KVCacheDataType` enum in `pa_common.cuh`
- Add the `centroids` kernel argument in `pa_v1.cpp.jinja` and `pa_v1.cuh`
- Add turbo4 K dequant in the K fetch loop (lines 221-240)
- Add turbo4 V dequant in the V fetch loop (lines 700-740)
- Select the FP16 MFMA path when turbo4 (not the FP8 MFMA)
- Add `turbo4` handling in `pa_v1.py` (Python JIT build)
- Test with a standalone benchmark (a sanity-check sketch follows this list)
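For the last item, a host-side sanity check of the position layout is a cheap first step before the kernel benchmark. A sketch, not the AITER benchmark; the centroid values and indices are arbitrary:

```cpp
// Check the turbo4 position layout: 64 nibble-packed bytes (128 dims)
// followed by a 2-byte FP16 norm at byte offset 64, dequantized with the
// same shifts and masks the kernel would use.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float centroids[16];
  for (int c = 0; c < 16; c++) centroids[c] = 0.1f * c - 0.8f;  // arbitrary table

  // Build one position: 4-bit indices packed low-nibble-first, then the norm.
  uint8_t pos[66];
  uint8_t idx[128];
  for (int d = 0; d < 128; d++) idx[d] = (5 * d + 3) & 0xF;
  for (int i = 0; i < 64; i++) pos[i] = idx[2 * i] | (idx[2 * i + 1] << 4);
  const uint16_t norm_bits = 0x4300;  // FP16 bit pattern for 3.5
  std::memcpy(pos + 64, &norm_bits, 2);

  // Dequantize and compare against the unpacked reference.
  const float norm = 3.5f;  // stand-in for __half2float on device
  for (int i = 0; i < 64; i++) {
    float lo = centroids[pos[i] & 0xF] * norm;
    float hi = centroids[(pos[i] >> 4) & 0xF] * norm;
    assert(std::fabs(lo - centroids[idx[2 * i]] * norm) < 1e-6f);
    assert(std::fabs(hi - centroids[idx[2 * i + 1]] * norm) < 1e-6f);
  }
  printf("turbo4 position layout round-trip OK\n");
  return 0;
}
```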