Environment
- MLX 0.31.2
- macOS, Apple M4
- Python 3.12
Summary
mx.fast.scaled_dot_product_attention returns mathematically incorrect results compared to manual attention, causing complete model failure (0% accuracy). The manual fallback produces correct results (99.5% accuracy), matching a Candle/Rust implementation with identical weights.
Model config
- seq_len=82 (81 Sudoku tokens + 1 puzzle embedding), hidden=512, heads=8, head_dim=64
- 4 H-layers + 4 L-layers, H_cycles=2, L_cycles=2
- RoPE (base=10000) with pre-computed cos/sin tables of shape [82, 64]
- No GQA (num_heads == num_kv_heads)
- Weights: sapientinc/HRM-checkpoint-sudoku-extreme
Behavior
Fused kernel (mx.fast.scaled_dot_product_attention): the model produces garbage output, 0% accuracy on Sudoku given cells. All 16 ACT steps run (the q-values never converge).
Manual attention: the model produces correct Sudoku solutions, 99.5% accuracy (214/215 given cells across 10 puzzles), and halts in 2-7 steps.
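For readers who want to recompute the accuracy figures, here is a minimal sketch. It assumes "accuracy on given cells" means the fraction of pre-filled clue cells that the model output reproduces unchanged. The helper name `given_cell_accuracy` and the use of 0 for a blank cell are assumptions for illustration and do not come from the model code.

```python
def given_cell_accuracy(puzzles, predictions, blank=0):
    """Count given (pre-filled) cells that the prediction reproduces.

    puzzles and predictions are sequences of flat 81-cell grids;
    `blank` marks an empty cell (hypothetical encoding).
    """
    correct = total = 0
    for puzzle, pred in zip(puzzles, predictions):
        for clue, out in zip(puzzle, pred):
            if clue != blank:
                total += 1
                correct += int(clue == out)
    return correct, total


# Toy example with a 3-cell "grid": two clues, one of them reproduced.
toy_correct, toy_total = given_cell_accuracy([[5, 0, 3]], [[5, 1, 4]])

# The reported figure: 214 of 215 given cells.
reported_pct = round(100 * 214 / 215, 1)
print(toy_correct, toy_total, reported_pct)
```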
Reproduction
The fix was to replace the fused call with the manual fallback that already exists in the code:
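# Broken: fused kernel, called on query/key/value as passed in (no transpose)
return mx.fast.scaled_dot_product_attention(query, key, value, scale=scale)
# Fixed: manual attention
# Swap axes 1 and 2 so the matmuls run over the last two axes
q = mx.transpose(query, (0, 2, 1, 3))
k = mx.transpose(key, (0, 2, 1, 3))
v = mx.transpose(value, (0, 2, 1, 3))
scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * scale
attn_weights = mx.softmax(scores, axis=-1)
out = attn_weights @ v
# Swap axes 1 and 2 back to the input layout
return mx.transpose(out, (0, 2, 1, 3))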
Full investigation log with side-by-side comparison against Candle/Rust available on request.
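To check either path independently of MLX, both outputs can be compared against a float64 NumPy reference on random inputs with the shapes from the model config. The sketch below assumes the tensors are laid out as [batch, seq, heads, head_dim], which is what the transposes in the manual path suggest. The function names are hypothetical. It also measures how much the result changes when the same arrays go into a heads-first routine without the transpose.

```python
import numpy as np


def heads_first_attention(q, k, v, scale):
    """Softmax attention for inputs laid out as [batch, heads, seq, head_dim]."""
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    # Numerically stable softmax over the last (key) axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def manual_attention(query, key, value, scale):
    """Reference for [batch, seq, heads, head_dim] inputs (assumed layout)."""
    # [B, L, H, D] -> [B, H, L, D]
    q, k, v = (np.transpose(t, (0, 2, 1, 3)) for t in (query, key, value))
    out = heads_first_attention(q, k, v, scale)
    # [B, H, L, D] -> [B, L, H, D]
    return np.transpose(out, (0, 2, 1, 3))


rng = np.random.default_rng(0)
B, L, H, D = 1, 82, 8, 64  # shapes from the model config above
q, k, v = (rng.standard_normal((B, L, H, D)) for _ in range(3))
scale = D ** -0.5

reference = manual_attention(q, k, v, scale)

# Same arrays, no transpose: axis 1 (82 positions) is treated as "heads"
# and attention runs across the 8 heads instead of the 82 positions.
untransposed = heads_first_attention(q, k, v, scale)

max_diff = np.abs(reference - untransposed).max()
print("max abs difference:", max_diff)
```

The MLX documentation describes the q, k and v inputs of mx.fast.scaled_dot_product_attention as [B, N_heads, T, D]. If the tensors here are in fact [batch, seq, heads, head_dim], a like-for-like comparison would pass the transposed q, k, v to the fused call as well, which is worth ruling out before attributing the difference to the kernel itself.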
Workaround
Use manual attention unconditionally. The fused kernel produces wrong attention output for this model configuration, which cascades into zero accuracy.