[BUG] SDPA NAX kernel: int16 overflow in mask col_pos for KV sequences > 32K #3360
Describe the bug
mx.fast.scaled_dot_product_attention produces incorrect output (2-10x magnitude error) when the KV sequence length exceeds 32,768 and an additive mask deactivates most positions. The bug is in the NAX attention kernel only — the non-NAX variant is correct.
In `mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h`, lines 324-325:

```cpp
const short row_pos = base_row + iq * kU;
const short col_pos = base_col + ik * kU;
```

`base_col` is an `int` (up to `kL`, e.g. 66048), but `col_pos` is a `short` (max 32767). When `base_col > 32767`, the value wraps negative, causing the mask to be loaded from the wrong positions.

The non-NAX variant (`steel_attention.h`, line 345) already uses `int` for `col_pos`.
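A minimal sketch of the wraparound, using `ctypes.c_int16` to mimic Metal's 16-bit `short` (the value 36864 is one of the failing caps from the repro below):

```python
import ctypes

# Mimic Metal's 16-bit `short` with ctypes.c_int16: a value past 32767
# wraps around modulo 2**16 into the negative range.
base_col = 36864  # a mask column offset beyond the int16 range
wrapped = ctypes.c_int16(base_col).value
print(wrapped)  # 36864 - 65536 = -28672: the mask is read at a negative offset
```

With `col_pos` negative, the kernel indexes the mask far from the intended column, so deactivated positions leak into the softmax.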
This is similar to #2894 / PR #2903 (int32 overflow in mask strides), but in the NAX code path.
To Reproduce
```python
import mlx.core as mx
import numpy as np

D_HEAD = 64

def test(cap, active, n_heads=32, n_kv=16, t=512, dtype=mx.bfloat16):
    np.random.seed(0)
    q = mx.array(np.random.randn(1, n_heads, t, D_HEAD).astype(np.float32)).astype(dtype)
    k = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    v = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    mask = mx.full((1, 1, 1, cap), -1e4).astype(dtype)
    mask[:, :, :, cap - active:] = 0.0
    mx.eval(q, k, v, mask)
    y_fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=D_HEAD**-0.5, mask=mask)
    mx.eval(y_fast)
    kk = mx.repeat(k, n_heads // n_kv, axis=1) if n_heads != n_kv else k
    vv = mx.repeat(v, n_heads // n_kv, axis=1) if n_heads != n_kv else v
    scores = (q @ kk.transpose(0, 1, 3, 2)) * (D_HEAD**-0.5) + mask
    y_ref = mx.softmax(scores.astype(mx.float32), axis=-1).astype(dtype) @ vv
    mx.eval(y_ref)
    yf = np.array(y_fast.astype(mx.float32))
    yr = np.array(y_ref.astype(mx.float32))
    ratio = np.nanmax(np.abs(yf)) / (np.max(np.abs(yr)) + 1e-10)
    ok = abs(ratio - 1.0) < 0.05
    status = "PASS" if ok else f"FAIL ({ratio:.3f})"
    print(f"  cap={cap:6d} active={active:5d} {status}")
    return ok

print(f"MLX {mx.__version__}\n")
n_fail = 0
for cap in [8192, 32768, 36864, 49152, 66048]:
    if not test(cap, 1024):
        n_fail += 1
print(f"\n{'ALL PASS' if n_fail == 0 else f'{n_fail} FAILED'}")
```

Expected output:
```
cap=  8192 active= 1024 PASS
cap= 32768 active= 1024 PASS
cap= 36864 active= 1024 PASS
cap= 49152 active= 1024 PASS
cap= 66048 active= 1024 PASS
```
Actual output:
```
cap=  8192 active= 1024 PASS
cap= 32768 active= 1024 PASS
cap= 36864 active= 1024 FAIL (0.719)
cap= 49152 active= 1024 FAIL (0.219)
cap= 66048 active= 1024 FAIL (0.123)
```
Expected behavior
`mx.fast.scaled_dot_product_attention` should match the reference (unfused) computation for all KV sequence lengths, including those above 32,768.
Desktop (please complete the following information):
- MLX version: 0.31.2 (commit 6a9a121)
- Apple M5 Max
- macOS 26.3.1
Additional context
Proposed Fix (PR #3361)
Change `short` to `int` on lines 324-325:

```diff
- const short row_pos = base_row + iq * kU;
- const short col_pos = base_col + ik * kU;
+ const int row_pos = base_row + iq * kU;
+ const int col_pos = base_col + ik * kU;
```

There are two occurrences of this pattern in the file (mask loading in the aligned and non-aligned branches).
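As a sanity check on the fix (a small sketch, not part of the original report): the caps that fail in the repro are exactly those whose largest mask column offset, `cap - 1`, exceeds the `short` maximum of 32767, while `int` comfortably holds all of them.

```python
SHORT_MAX = 32767  # largest value a Metal `short` (int16) can hold

# The failing caps (36864, 49152, 66048) are precisely those whose
# largest reachable column offset overflows int16.
for cap in [8192, 32768, 36864, 49152, 66048]:
    max_col = cap - 1  # largest column offset a mask load can reach
    print(cap, "overflows short" if max_col > SHORT_MAX else "fits in short")
```

This matches the observed PASS/FAIL boundary: cap = 32768 still passes because its maximum offset is exactly 32767.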