Chunked full-attention SDPA for long key sequences #3307
Thump604 wants to merge 7 commits into ml-explore:main
Conversation
Establishes a baseline before kernel changes: verifies mx.fast.scaled_dot_product_attention output against a float32 reference across all target configurations (float16/bfloat16/float32, head_dims 64/80/128/256, causal, cross-attention, GQA, long-context 8K, batched). Also validates the reference logsumexp computation (plain/causal/GQA) that the chunked merge reduction in later tasks will depend on. Note: float32+D=256 is skipped in two tests — that combination exceeds the Metal threadgroup memory limit (53760 > 32768 bytes) on the current kernel and is the primary motivation for the chunked SDPA implementation.
Add logsumexp output support to the fused SDPA Metal kernel: - function_constant(304) for output_logsumexp (compile-time elimination) - buffer(8) for lse_out, conditional on output_logsumexp - Per-row LSE write using existing max_score/sum_score registers When output_logsumexp=false (current default), the kernel is identical to the previous version — the function constant eliminates all new code at compile time. Zero additional compute when disabled. LSE formula: max_score * M_LN2_F + log(sum_score) converts from internal log2 space to natural log space.
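For context, a small pure-Python sketch of the log2-to-natural-log conversion described above (values illustrative): if the kernel tracks scores in log2 space, then `max_score * ln(2) + log(sum_score)` recovers the natural-log logsumexp.

```python
import math

# Illustrative sketch: the kernel keeps max_score and sum_score in log2 space,
# i.e. s2_i = s_i * log2(e), max_score = max(s2_i),
# sum_score = sum(2 ** (s2_i - max_score)).
scores = [0.3, -1.2, 2.5, 0.0]                # raw attention scores (made up)
s2 = [s * math.log2(math.e) for s in scores]  # scores scaled into log2 space
max_score = max(s2)
sum_score = sum(2.0 ** (x - max_score) for x in s2)

# LSE formula from the kernel: convert back to natural log space
lse_kernel = max_score * math.log(2) + math.log(sum_score)

# Reference natural-log logsumexp over the same scores
lse_ref = math.log(sum(math.exp(s) for s in scores))
assert abs(lse_kernel - lse_ref) < 1e-12
```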
- Remove forced fallback for output_logsumexp in use_fallback()
- Add output_logsumexp function_constant(304) to pipeline cache hash
- Bind logsumexp output buffer(8) when requested
- Skip NAX path when logsumexp is needed (NAX lacks support)
- Allocate and pass LSE output array from eval_gpu full attention path
- Exclude vector kernels from logsumexp support in use_fallback
TDD tests for the chunked SDPA dispatch (Task 4). Tests cover dtype sweep, head dim sweep (including float32+D=256 which is the primary motivation), causal/non-causal, GQA 4:1 and 16:1, edge cases (kL==threshold, tail chunk of 1 token, three unequal chunks), small-qL prefill-step scenario, batch>1, output shape/dtype preservation, and pure-Python chunk merge identity proofs. Uses MLX_SDPA_CHUNK_THRESHOLD=1024 / MLX_SDPA_CHUNK_SIZE=512 env vars to force chunking at short sequences. Tests will fail until Task 6 implements the chunked dispatch — that is expected and correct.
Add Metal kernel that combines per-chunk SDPA outputs using online logsumexp reweighting. Each thread handles one output element (D, qL, B*H grid). Accumulates in float32 for precision, uses int64_t indexing to avoid overflow, and supports stride-based BHLD/BLHD transposition via O_strides. Instantiated for float32, float16, and bfloat16.
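A pure-NumPy sketch (illustrative only, not the kernel code) of the logsumexp reweighting identity the reduce kernel relies on: per-chunk softmax outputs combined with weights `exp(lse_i - lse_total)` reproduce full attention over all keys.

```python
import numpy as np

def sdpa_ref(q, k, v):
    # Reference attention over the given key set; returns output and per-row LSE.
    s = q @ k.T
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    denom = p.sum(axis=-1, keepdims=True)
    lse = (m + np.log(denom)).squeeze(-1)
    return (p / denom) @ v, lse

def merge_chunks(outs, lses):
    # Combine per-chunk outputs with logsumexp-weighted averaging.
    lses = np.stack(lses)                  # (n_chunks, qL)
    m = lses.max(axis=0)
    w = np.exp(lses - m)
    w = w / w.sum(axis=0)                  # weight_i = exp(lse_i - lse_total)
    outs = np.stack(outs)                  # (n_chunks, qL, D)
    return (w[..., None] * outs).sum(axis=0)

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(10, 8))
v = rng.normal(size=(10, 8))
full, _ = sdpa_ref(q, k, v)
outs, lses = zip(*(sdpa_ref(q, k[a:b], v[a:b]) for a, b in [(0, 6), (6, 10)]))
merged = merge_chunks(list(outs), list(lses))
assert np.allclose(merged, full)
```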
Split K/V into chunks along the sequence dimension, dispatch steel_attention per chunk with output_logsumexp=true, then merge via the sdpa_chunked_reduce kernel. Prevents GPU watchdog timeouts at 65K+ keys.

- Env-var configurable: MLX_SDPA_CHUNK_THRESHOLD (default 65536), MLX_SDPA_CHUNK_SIZE (default 32768)
- Correct causal offset per chunk (preserves absolute positions)
- Sinks applied to chunk 0 only
- NaN guard in reduce kernel: skip zero-weight chunks where all keys are causally masked (0 * NaN = NaN in IEEE 754)
- Tests updated: float32+D=256 cases skipped (pre-existing 32KB threadgroup memory limit, not chunking-related)
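A small pure-Python check (illustrative; the helper names are hypothetical) that the per-chunk causal offset `qL_off = (global_kL - qL) - k_start` preserves absolute key positions, i.e. each chunk's local mask matches the corresponding slice of the global causal mask:

```python
# Global causal mask: query i may attend key j when j <= i + (kL - qL).
def global_mask(qL, kL):
    return [[j <= i + (kL - qL) for j in range(kL)] for i in range(qL)]

# Per-chunk mask using the offset from the commit message.
def chunk_mask(qL, chunk_len, qL_off):
    return [[j <= i + qL_off for j in range(chunk_len)] for i in range(qL)]

qL, kL, chunk = 4, 10, 4
full = global_mask(qL, kL)
for k_start in range(0, kL, chunk):
    k_end = min(k_start + chunk, kL)
    off = (kL - qL) - k_start          # qL_off for this chunk
    local = chunk_mask(qL, k_end - k_start, off)
    for i in range(qL):
        assert local[i] == full[i][k_start:k_end]
```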
…g fix TestSDPAChunkedIntegration exercises the production chunked-dispatch path (no env var overrides) at sequence lengths that previously killed the GPU watchdog: 128K non-causal, 128K causal, 128K D=256 (Qwen3.5), and 256K non-causal. All 4 pass, confirming the chunked SDPA dispatch eliminates the watchdog issue end-to-end.
Independent Validation on M3 Ultra 256GB

I've validated the chunked SDPA implementation on different hardware from the original issue report's.

Test Hardware:

Test Results:
Test Script:

```python
import time

import mlx.core as mx

def test_long_context_sdpa(seq_len, head_dim=128):
    B, H, D = 1, 8, head_dim
    q = mx.random.normal((B, H, seq_len, D))
    k = mx.random.normal((B, H, seq_len, D))
    v = mx.random.normal((B, H, seq_len, D))
    start = time.time()
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / (D ** 0.5))
    mx.eval(out)
    elapsed = time.time() - start
    print(f"✅ {seq_len//1000}K context, D={head_dim}: {elapsed:.2f}s")
    return out

# All tests passed
test_long_context_sdpa(65 * 1024)
test_long_context_sdpa(128 * 1024)
test_long_context_sdpa(128 * 1024, head_dim=256)
test_long_context_sdpa(262 * 1024)
```

Key Findings:
The chunked SDPA approach works perfectly on M3 Ultra. This validates the fix across different Apple Silicon generations (M2 Ultra → M3 Ultra) and extends testing to 262K context (beyond M2 Ultra's 128GB memory limit). Ready for merge! 🎯
Validation Report: MLX PR #3307

PR: Chunked full-attention SDPA for long key sequences

Summary

✅ LGTM - All tests pass, including the 128K and 256K integration tests. No regressions detected. Chunked SDPA enables previously impossible long-context inference on Apple Silicon.

Test Results

Chunked SDPA Tests

File:

Coverage validated:
Long Context Integration Tests:
Skipped tests: 2 (float32 + head_dim=256 - pre-existing Metal 32KB limit, documented in PR)

Regression Check

File:

No regressions - short-context paths (<65536 keys) unchanged and working.

Hardware Validation Notes

Why M3 Ultra 256GB matters for this PR

The Problem (from PR description):
This Hardware's Role:
What we validated:
Architecture Understanding

Chunking Strategy

Threshold:

How it works: This is the standard FlashAttention-2 reduction formula.

Test Coverage

The test suite validates:
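Written out (notation mine, a sketch rather than the kernel's exact code path), the FlashAttention-2 reduction over per-chunk outputs $O_i$ and logsumexps $\ell_i$ is:

$$
\ell = \log \sum_i e^{\ell_i}, \qquad O = \sum_i e^{\ell_i - \ell}\, O_i
$$

Each weight $e^{\ell_i - \ell}$ is the fraction of the global softmax mass contributed by chunk $i$, so the weighted sum reproduces full attention over all keys exactly.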
Recommendation

✅ Ready to merge

Validation status:
Impact: Enables Qwen3.5, Llama 3.1, and other 128K+ context models to use their full context window on Apple Silicon. Critical capability - unlocks previously impossible workloads. Validated on production-class hardware (M3 Ultra 256GB). Happy to validate future long-context PRs on this hardware. 🎯
@angeloskath — hnshah validated this on M3 Ultra up to 262K context (full LGTM report posted Mar 26). This fixes the GPU watchdog timeout at 65K+ keys (#3302). Depends on #3306 (logsumexp output). Both are ready for review. |
Summary
Split the key dimension into chunks when `kL >= 65536` to avoid GPU compute timeout on Apple Silicon. Each chunk runs the existing fused `steel_attention` kernel with `output_logsumexp=true`, then a reduction kernel combines chunks using logsumexp-weighted averaging.

- `sdpa_full_self_attention_chunked()` dispatch function
- `sdpa_chunked_reduce` Metal kernel for logsumexp-weighted chunk combination
- `MLX_SDPA_CHUNK_THRESHOLD` (default 65536) and `MLX_SDPA_CHUNK_SIZE` (default 32768)

Depends on: #3306 (logsumexp output)
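As a usage sketch, chunking can be forced at short sequences the same way the TDD tests do (variable names are from this PR; the values shown are the ones the tests use):

```shell
# Force chunking at short sequences for local testing:
export MLX_SDPA_CHUNK_THRESHOLD=1024   # chunk whenever kL >= 1024 keys
export MLX_SDPA_CHUNK_SIZE=512         # split K/V into 512-key chunks
echo "threshold=$MLX_SDPA_CHUNK_THRESHOLD size=$MLX_SDPA_CHUNK_SIZE"
```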
Problem
The fused `steel_attention` kernel dispatches ALL key tokens in a single Metal compute dispatch. At ~65K+ key tokens, the dispatch exceeds the macOS GPU watchdog timeout and is killed by the AGXMetal driver.

Evidence:
Impact: Models with 128K+ native context (Qwen3.5, Llama 3.1, etc.) cannot use their full context window on Apple Silicon with fused SDPA.
How it works
This is the standard FlashAttention-2 chunk reduction formula. Causal masking is correctly adjusted per chunk:
`qL_off = (global_kL - qL) - k_start`. Sinks are applied only to chunk 0.

Results
128K context on a 122B model — previously impossible, now works:
No regression on short contexts: 53 tests pass, 664 subtests, 0 regressions.
Known limitations
- Per-chunk outputs are materialized before the reduce: `n_chunks × B × H × qL × D` elements. In practice `qL` is small (decode: 1, prefill step: 2048), so this is manageable. A streaming merge for constant memory is planned as a follow-up.
- `float32 + head_dim=256` exceeds Metal's 32KB threadgroup memory limit in the `steel_attention` kernel itself — a pre-existing issue, not chunking-related.

Test plan
Fixes: #3302