Near-optimal KV cache quantization for LLM serving on AMD GPUs
TurboQuant compresses the KV cache of large language models to 2–4 bits per element with near-optimal distortion, enabling 2× KV cache capacity and 71% higher throughput at high concurrency on AMD Instinct GPUs.
Based on the paper TurboQuant: Online Vector Quantization for KV Cache Quantization (ICLR 2026).
| | turbo4 (4-bit) | turbo3 (3-bit) | turbo2 (2-bit) |
|---|---|---|---|
| Compression | 3.8× vs FP16 | 4.6× vs FP16 | 5.7× vs FP16 |
| PPL vs q8_0 | +0.23% | +1.06% | +2.8% |
| GSM8K | 95% (= FP8) | 95% (= FP8) | — |
| Recommended for | Quality-first | Balanced | Max compression |
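The distortion behind these ratios comes from quantizing approximately Gaussian coordinates with a Lloyd-Max codebook (described in the pipeline section below). As an illustrative sketch, not the library's implementation, a sample-based Lloyd-Max quantizer for N(0,1) can be built with plain numpy:

```python
import numpy as np

def lloyd_max_gaussian(bits: int, n_samples: int = 200_000, iters: int = 50, seed: int = 0):
    """Sample-based Lloyd-Max quantizer for N(0,1): alternate nearest-centroid
    assignment and centroid update (1-D k-means on Gaussian samples)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)
    levels = 2 ** bits
    # Initialize centroids at evenly spaced empirical quantiles
    centroids = np.quantile(x, (np.arange(levels) + 0.5) / levels)
    for _ in range(iters):
        boundaries = (centroids[:-1] + centroids[1:]) / 2   # midpoints between centroids
        idx = np.searchsorted(boundaries, x)                # nearest-centroid assignment
        centroids = np.array([x[idx == j].mean() for j in range(levels)])
    return centroids, boundaries

c1, _ = lloyd_max_gaussian(bits=1)
# The 1-bit optimum for N(0,1) is the known closed form ±sqrt(2/pi) ≈ ±0.798
```

Higher bit-widths yield codebooks analogous to what `make_codebook` provides for the turbo configs.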
At ISL=31K, concurrency=64 (the workloads that matter for production):
| Metric | TurboQuant | Vanilla FP8 | Improvement |
|---|---|---|---|
| Output throughput | 71.60 tok/s | 41.92 tok/s | +71% |
| Mean TTFT | 33.5s | 965.2s | 29× faster |
| KV cache capacity | 2.59M tokens | 1.29M tokens | 2.0× |
| Failed requests | 0/640 | 1/640 | — |
TurboQuant's 2× KV capacity means more requests run simultaneously instead of queueing. At high concurrency, vanilla FP8 throughput collapses while TQ holds steady.
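A back-of-envelope check of why the gap opens up (capacity figures are the benchmark's from the table above; the arithmetic is illustrative):

```python
# KV demand at ISL=31K, concurrency=64, versus measured KV capacities
concurrency, isl = 64, 31_000
tokens_in_flight = concurrency * isl     # ~1.98M tokens of KV needed at once

fp8_capacity = 1_290_000                 # vanilla FP8 KV slots (from the table)
tq_capacity = 2_590_000                  # TurboQuant KV slots (2.0x)

# FP8 cannot hold the whole batch at once: requests queue and TTFT balloons.
# TurboQuant fits all 64 requests, so decode proceeds concurrently.
print(tokens_in_flight > fp8_capacity)   # True
print(tokens_in_flight < tq_capacity)    # True
```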
```
Input K/V (FP16) ──→ WHT Rotate ──→ PolarQuant ──→ Pack ──→ Compressed KV
                         │               │              │
                         │ Gaussianize   │ Lloyd-Max    │ 2-4 bit indices
                         │ coordinates   │ quantization │ + norms + signs
                         │               │              │
                         ▼               ▼              ▼
                     Orthogonal     Near-optimal    4.6× smaller
                    (self-inverse)    distortion     (turbo3)
```
- WHT Rotation: The Walsh-Hadamard Transform maps structured vectors into ones with approximately i.i.d. Gaussian coordinates. This is the key insight: Gaussian coordinates enable near-optimal scalar quantization.
- PolarQuant: Decomposes each rotated vector into a norm (scalar) plus a direction (quantized per-coordinate via a Lloyd-Max codebook). Achieves distortion within ~2.7× of the information-theoretic lower bound.
- Sign Extraction (turbo2/3): Stores 1-bit QJL correction signs for unbiased inner product estimation. Not needed for turbo4 (16 centroids give sufficient quality).
- Bit Packing: Packs indices into bytes with stride-4 interleaving for efficient GPU access patterns.
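A minimal numpy sketch of the rotation step (illustrative; the library builds its matrix via `make_wht_matrix`). The Sylvester construction gives a symmetric orthonormal matrix, so it is its own inverse, and applying it spreads a structured vector's energy across all coordinates:

```python
import numpy as np

def wht_matrix(d: int) -> np.ndarray:
    """Normalized Walsh-Hadamard matrix via the Sylvester recursion (d a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(d)

H = wht_matrix(128)
assert np.allclose(H @ H, np.eye(128))   # symmetric + orthonormal => self-inverse

# A "structured" vector: energy concentrated in a few coordinates
x = np.zeros(128)
x[:4] = [5.0, -3.0, 2.0, 4.0]
y = H @ x                                # rotated: energy spread out, norm preserved
assert np.isclose(np.linalg.norm(y), np.linalg.norm(x))
```

(In practice the transform is applied with an O(d log d) butterfly rather than a dense matmul.)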
For keys, TurboQuant computes attention scores directly from compressed data:
<q, k> ≈ <q, k_mse> + ||r_k|| × √(π/2)/m × <S@q, sign(S@r_k)>
The MSE term uses centroid reconstruction; the QJL correction provides an unbiased inner product estimate. For values, MSE-only reconstruction suffices (errors average out in the weighted sum).
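A Monte Carlo sketch of why the correction term is unbiased (illustrative; S here is a fresh Gaussian sketch standing in for the QJL matrix). For Gaussian rows s_i, E[⟨s_i, q⟩ · sign(⟨s_i, r⟩)] = √(2/π) · ⟨q, r⟩ / ‖r‖, which is exactly what the ‖r_k‖ · √(π/2)/m scaling inverts:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 128, 8192
q = rng.standard_normal(d)               # query
r = q + 0.1 * rng.standard_normal(d)     # stand-in for the residual r_k

S = rng.standard_normal((m, d))          # Gaussian sketch with m rows
est = np.linalg.norm(r) * np.sqrt(np.pi / 2) / m * np.dot(S @ q, np.sign(S @ r))

exact = q @ r
print(abs(est - exact) / abs(exact) < 0.05)   # True: estimate concentrates around <q, r>
```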
WHT rotation cost can be completely eliminated by fusing the rotation matrix into model weights at load time:
```python
from turboquant.core.rotation import apply_weight_fusion

state_dict = apply_weight_fusion(state_dict, num_heads=64, num_kv_heads=8, head_size=128)
# Rotation cost: 68μs/layer → 0μs/layer
# Cosine similarity: 1.000000 (mathematically exact)
```

Six fused Triton kernels handle the full pipeline on AMD Instinct GPUs:
| Kernel | File | Function |
|---|---|---|
| A | kernels/compress_k.py | Fused K compression (norm + rotate + quantize + QJL) |
| B | kernels/compress_v.py | Fused V compression (norm + rotate + quantize) |
| C | kernels/attention.py | Asymmetric attention decode (split-K) |
| A' | kernels/packed_compress.py | Packed K+V compression (bit-packed output) |
| B' | kernels/packed_attention.py | Packed attention v7 (production decode kernel) |
| C' | kernels/sparse_v.py | Sparse-V v12 decode (skip near-zero weights) |
| Kernel | Time/Layer | Platform |
|---|---|---|
| TQ turbo3 v12 (Sparse-V) | 843 μs | MI355X |
| TQ turbo4 | 797 μs | MI355X |
| AITER FP8 PA (target) | 48 μs | MI355X |
| GPU | Architecture | Status |
|---|---|---|
| AMD Instinct MI355X | gfx950 | ✅ Primary target |
| AMD Instinct MI300X | gfx942 | ✅ Tested |
| ROCm | 7.0+ | Required |
```shell
# From source
git clone https://github.com/andyluo7/turboquant-amd.git
cd turboquant-amd
pip install -e ".[dev]"

# Or directly
pip install git+https://github.com/andyluo7/turboquant-amd.git
```

- Python ≥ 3.10
- PyTorch ≥ 2.4.0 (with ROCm support)
- Triton ≥ 3.0.0
- AMD Instinct GPU (MI300X or MI355X)
- ROCm ≥ 7.0
```python
from turboquant.core import make_codebook, make_wht_matrix, PolarQuantConfig

# Create optimal codebook for 3-bit quantization
centroids, boundaries = make_codebook(d=128, bits=3, device="cuda")

# WHT rotation matrix
H = make_wht_matrix(128, device="cuda")  # H @ H = I (self-inverse)

# Configure compression
config = PolarQuantConfig(head_dim=128, bits=3, use_qjl=True)
print(f"Compression ratio: {config.compression_ratio:.1f}x")  # 4.6x
```

```python
import torch

from turboquant.core import polarquant_compress

# Simulate KV cache entries
K = torch.randn(1024, 128, device="cuda")

# Compress
compressed = polarquant_compress(K, centroids, boundaries, config, rotation_matrix=H)
# compressed["indices"]:   [1024, 128] uint8
# compressed["norms"]:     [1024] float
# compressed["signs"]:     [1024, 128] ±1
# compressed["k_mse_rot"]: [1024, 128] MSE reconstruction
```

```shell
pytest tests/ -v
```

```shell
# Kernel-level benchmarks (requires GPU)
python benchmarks/bench_kernels.py

# E2E serving benchmark (requires vLLM)
python benchmarks/bench_e2e.py --tq-port 8000 --vanilla-port 8001
```

TurboQuant integrates with vLLM as a custom KV cache backend:
```shell
# Launch with TurboQuant KV cache
python -m vllm.entrypoints.openai.api_server \
    --model MiniMax-Text-01 \
    --tensor-parallel-size 2 \
    --kv-cache-dtype turboquant
```

The turbo4→FP8 pipeline compresses KV entries via PolarQuant and stores them as standard FP8, enabling use of AITER's optimized FP8 paged attention for decode.
Experimental turbo4 support via custom GGML quantization types. See turboquant/integration/llamacpp.py for details.
Results on MI300X:
- Prefill: +4% vs f16
- Decode: 89% of f16
- PPL: +0.23% vs q8_0
| Benchmark | TurboQuant | FP8 | Notes |
|---|---|---|---|
| GSM8K 5-shot | 95% | 95% | MiniMax-M2.5, TP=2, MI355X |
| WHT roundtrip | 2.98e-07 max err | — | Numerically exact |
| Weight fusion | cosine = 1.0 | — | Mathematically exact |
| turbo4 PPL | +0.23% vs q8_0 | — | Negligible |
| turbo3 PPL | +1.06% vs q8_0 | — | Very good |
```
turboquant-amd/
├── turboquant/
│   ├── core/          # Codebook, rotation, quantizer
│   ├── kernels/       # Fused Triton GPU kernels
│   ├── integration/   # vLLM + llama.cpp backends
│   └── reference/     # PyTorch reference implementation
├── benchmarks/        # Kernel and E2E benchmarks
├── tests/             # Unit tests
└── docs/              # Architecture, kernel history, results
```
We welcome contributions! Here are the key areas where help is needed:
| Task | Description | Difficulty | Impact |
|---|---|---|---|
| AITER turbo4 PA kernel | Fork AITER's HIP paged attention to read turbo4 (4-bit nibble) KV directly in the attention inner loop. Target: match AITER FP8 speed (~48 μs/layer). See design doc. | Hard | 10× decode speedup |
| Native FP4 MFMA on gfx950 | Use mfma_scale_f32_16x16x128_f8f6f4 for hardware FP4 decode on MI355X. Currently returns zeros on ROCm 7.0.0; needs testing on 7.0.1+. See design doc. | Hard | 2× capacity at FP8 speed |
| Triton codegen optimization | Current Triton kernels achieve only 3% HBM bandwidth utilization on ROCm (vs AITER's 61%). Profile and optimize the generated HIP assembly. | Hard | 3-10× kernel speedup |
| Task | Description | Difficulty | Impact |
|---|---|---|---|
| Boundary layer protection | Skip first 2 + last 2 layers (keep at full precision). Proven to recover 37-91% of quality gap. Easy to add. | Easy | Quality improvement |
| Outlier-aware mixed precision | Per-layer channel variance calibration → high-variance channels get more bits (paper Section 4.3). See Hyperloom reference. | Medium | Quality improvement |
| Asymmetric K/V compression | Compress V more aggressively than K (V compression is free — confirmed by turboquant_plus). E.g., turbo4-K + turbo2-V. | Easy | Better compression ratio |
| Perplexity benchmark suite | Add wikitext-2 / wikitext-103 PPL evaluation for all turbo configs. Currently only have GSM8K accuracy. | Easy | Quality validation |
| More model support | Add set_dflash_layers_to_capture equivalent for DeepSeek-V3, Llama-4, Gemma-4 in vLLM. | Easy | Broader adoption |
| Task | Description | Difficulty | Impact |
|---|---|---|---|
| SGLang integration | Port the vLLM backend to SGLang's attention framework. | Medium | New serving engine |
| llama.cpp turbo4 Metal kernel | Register LUT decode for Apple Silicon (M-series). TheTom's fork already has turbo3. | Medium | Apple Silicon support |
| ONNX export | Export TQ-compressed models to ONNX with custom ops. | Medium | Cross-platform |
| Quantization-aware training | Fine-tune with TQ compression in the loop for better quality. | Hard | Quality at lower bits |
| Dynamic bit allocation | Per-layer bit-width selection based on attention entropy. | Medium | Adaptive compression |
| Task | Description | Difficulty |
|---|---|---|
| CI/CD pipeline | GitHub Actions with MI300X self-hosted runner for kernel tests | Medium |
| Benchmark dashboard | Automated throughput/quality tracking across commits | Medium |
| pip installable | Make pip install turboquant-amd work cleanly | Easy |
| Documentation | API docs, tutorials, integration guides | Easy |
| Issue | Description | Workaround |
|---|---|---|
| Triton ROCm codegen | Only 3% HBM bandwidth utilization vs AITER's 61% | Use turbo4→FP8 pipeline for production |
| mfma_scale zeros | __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 returns all zeros on ROCm 7.0.0 | Needs ROCm 7.0.1+ testing |
| Sparse-V v12 memory faults | Intermittent GPU memory access faults at ISL=31K conc=64 on MI355X | Root cause: Triton codegen + GPU-pair sensitivity |
| AITER broken in SGLang Docker | Public SGLang ROCm images have broken AITER imports | pip install -e . inside container |
| vLLM 50B compact KV NCCL crash | Non-power-of-2 slot size causes vectorized load overflow | Fixed: use 64B padded slots |
- Fork the repo and create a feature branch
- Test on MI300X or MI355X with ROCm 7.0+
- Benchmark with the provided scripts in benchmarks/
- Submit a PR with test results
For questions, open an issue or reach out to @andyluo7.
```bibtex
@inproceedings{turboquant2026,
  title={TurboQuant: Online Vector Quantization for KV Cache Quantization},
  year={2026},
  booktitle={International Conference on Learning Representations (ICLR)},
  url={https://arxiv.org/abs/2504.19874}
}
```

- TurboQuant paper (ICLR 2026)
- vLLM — LLM serving engine
- AITER — AMD inference engine with FP8 paged attention
- llama.cpp — C/C++ LLM inference
MIT — see LICENSE.