Add TurboQuant KV cache compression with native Metal SDPA kernel #3328

Open
arozanov wants to merge 5 commits into ml-explore:main from arozanov:feature/turboquant-kv-cache

Conversation

arozanov commented Mar 28, 2026

Proposed changes

Adds TurboQuant (arXiv 2504.19874) as a native Metal SDPA kernel for KV cache compression.

  • New QuantizationMode::TurboQuant
  • sdpa_vector_turbo Metal kernel: reads 3-bit packed K indices with codebook dequant
  • Pre-rotated query optimization: no WHT butterfly in attention inner loop
  • TurboQuantSDPA primitive with full eval_gpu dispatch
  • Python API: mx.fast.turboquant_sdpa()
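The packing and dequant step the kernel performs can be sketched in numpy. This is a minimal illustration only: the word layout (10 indices per uint32, LSB-first) and the helper names `pack_3bit`/`dequant` are assumptions for exposition, not the actual Metal kernel's memory layout.

```python
import numpy as np

def pack_3bit(indices: np.ndarray) -> np.ndarray:
    """Pack 3-bit indices (values 0..7) into uint32 words, 10 per word.
    Illustrative layout; the real kernel's layout may differ."""
    words = []
    for start in range(0, len(indices), 10):
        word = np.uint32(0)
        for i, idx in enumerate(indices[start:start + 10]):
            word |= np.uint32(idx) << np.uint32(3 * i)
        words.append(word)
    return np.array(words, dtype=np.uint32)

def dequant(packed: np.ndarray, codebook: np.ndarray, n: int) -> np.ndarray:
    """Unpack n 3-bit indices and look each up in an 8-entry codebook."""
    out = np.empty(n, dtype=codebook.dtype)
    for j in range(n):
        word = int(packed[j // 10])
        idx = (word >> (3 * (j % 10))) & 0b111
        out[j] = codebook[idx]
    return out
```

Round-tripping a few indices through `pack_3bit` and `dequant` recovers exactly the codebook entries they select.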

Benchmarks (M4 Pro 48GB, 28 query heads, 4 KV heads, D=128):

Context   Apple SDPA   TurboQuant   Speedup
1K        0.161 ms     0.108 ms     1.5x
4K        0.164 ms     0.109 ms     1.5x
8K        0.209 ms     0.106 ms     2.0x
16K       0.511 ms     0.104 ms     4.9x

TurboQuant reads 3-bit packed data, about 4.6x less memory traffic than fp16. Kernel time stays roughly constant at ~0.1 ms regardless of context length.
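As a rough illustration of where the saving comes from, here is the per-layer KV cache footprint at the benchmark shape. Scale/codebook metadata is ignored in this sketch, which is presumably why the measured figure is ~4.6x rather than the raw 16/3 ≈ 5.3x.

```python
# KV cache bytes for one layer (4 KV heads, D=128), fp16 vs 3-bit
# packed storage. Quantization metadata (scales, codebook) is not
# counted here, so this is an upper bound on the saving.
def kv_bytes(ctx_len, kv_heads=4, head_dim=128, bits=16):
    return 2 * ctx_len * kv_heads * head_dim * bits / 8  # K and V

for ctx in (1024, 4096, 16384):
    mib = 2 ** 20
    print(f"{ctx:>6} tokens: fp16 {kv_bytes(ctx) / mib:.2f} MiB, "
          f"3-bit {kv_bytes(ctx, bits=3) / mib:.2f} MiB")
```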

Related: standalone package at https://github.com/arozanov/turboquant-mlx and mlx-lm PR at ml-explore/mlx-lm#1067

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Adds TurboQuant (arXiv 2504.19874) as a new quantization mode for
KV cache compression in MLX core.

Changes:
- QuantizationMode::TurboQuant enum + string conversion
- sdpa_vector_turbo Metal kernel: reads bit-packed uint32 K indices
  with codebook dequant, pre-rotated query optimization (no WHT
  in inner loop). Instantiated for fp16/bf16 x 64/128 dim x 3/4 bit.
- C++ dispatch function sdpa_vector_turbo() in SDPA backend
- Python binding mx.fast.turboquant_sdpa()
- CMake fix: removed -sdk macosx from xcrun metal invocation
  (Metal Toolchain installed via xcodebuild -downloadComponent)

Status: Metal kernel compiled and instantiated, C++ dispatch ready,
Python binding exposed. At this commit the binding still falls back to
regular SDPA; full native dispatch needs a TurboQuantSDPA Primitive
subclass to wire eval_gpu to the turbo kernel.

- TurboQuantSDPA primitive class in fast_primitives.h
- eval_gpu() routes to sdpa_vector_turbo Metal kernel
- Full pipeline: Python mx.fast.turboquant_sdpa() → C++ → Metal
- Pre-rotated query: no WHT butterfly in attention inner loop
- Kernel reads bit-packed uint32 K indices + codebook directly
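The pre-rotated query optimization relies on the Walsh-Hadamard matrix being symmetric, so q·(Hk) = (Hq)·k and the rotation can be hoisted out of the per-key loop. A small numpy check of that identity (the kernel's actual transform and normalization are not shown in this PR):

```python
import numpy as np

def wht(x):
    """Fast Walsh-Hadamard transform over the last axis (length must be
    a power of two). Unnormalized Sylvester ordering; H is symmetric."""
    x = np.array(x, dtype=np.float64)
    n = x.shape[-1]
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            a = x[..., i:i + h].copy()
            b = x[..., i + h:i + 2 * h].copy()
            x[..., i:i + h] = a + b
            x[..., i + h:i + 2 * h] = a - b
        h *= 2
    return x

# Since H is symmetric, rotating q once offline gives the same score
# as applying the butterfly to every cached key in the inner loop.
rng = np.random.default_rng(0)
q = rng.standard_normal(128)
k = rng.standard_normal(128)
assert np.allclose(np.dot(q, wht(k)), np.dot(wht(q), k))
```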
Native Metal kernel benchmarks (turbo kernel time as a fraction of standard SDPA):
  256 tokens: 0.83x
  1K tokens:  0.71x (turbo faster)
  4K tokens:  0.49x (turbo ~2x faster)

TurboQuant reads 3-bit packed data, so it needs far less memory bandwidth than fp16.
Native Metal kernel benchmarks, speedup vs. Apple SDPA (28 query heads, 4 KV heads, D=128):
  256 tokens:  0.8x (slower; quantization overhead dominates at short context)
  1K tokens:   1.5x faster
  4K tokens:   1.5x faster
  8K tokens:   2.0x faster
  16K tokens:  4.9x faster

TurboQuant kernel stays at ~0.1ms regardless of context length.
Apple SDPA grows linearly with context (memory bandwidth limited).

Changes:
- Proper buffer allocation with donation in eval_gpu
- Contiguous copy handling
- CPU fallback for non-GPU paths
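Putting the pieces together, the per-head computation the kernel performs can be written as a plain numpy reference: dequantize K from codebook indices, score against the already-rotated query, softmax, and weight V. Function name, argument names, and the unpacked index layout are illustrative assumptions; the actual kernel fuses the 3-bit unpacking and operates on Metal-side buffers.

```python
import numpy as np

def turbo_sdpa_ref(q_rot, k_idx, codebook, v, scale):
    """Single-query attention over a quantized K cache (reference only).
    q_rot:    (D,) query, already WHT-rotated offline
    k_idx:    (T, D) codebook indices per key element (already unpacked)
    codebook: (8,) dequantization values for 3-bit indices
    v:        (T, D) unquantized values
    A plausible per-head sketch; packing and Metal memory layout omitted."""
    k = codebook[k_idx]                  # (T, D) dequantized keys
    scores = scale * (k @ q_rot)         # (T,) attention logits
    w = np.exp(scores - scores.max())    # numerically stable softmax
    w /= w.sum()
    return w @ v                         # (D,) attention output
```

This matches ordinary SDPA computed on the dequantized keys, which is a convenient correctness check for the fused kernel.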