[BUG] SDPA NAX kernel: int16 overflow in mask col_pos for KV sequences > 32K #3360

@Clydingus

Description

Describe the bug
mx.fast.scaled_dot_product_attention produces incorrect output (2-10x magnitude error) when the KV sequence length exceeds 32,768 and an additive mask deactivates most positions. The bug is in the NAX attention kernel only — the non-NAX variant is correct.

In mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h, lines 324-325:

const short row_pos = base_row + iq * kU;
const short col_pos = base_col + ik * kU;

base_col is an int (it can reach kL, e.g. 66048), but col_pos is a short (max 32767). When base_col exceeds 32767, the assignment truncates to 16 bits and wraps to a negative value, so the mask is loaded from the wrong positions.
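The wraparound can be demonstrated outside Metal. This sketch (plain Python, using ctypes to emulate the 16-bit signed assignment `short col_pos = base_col;`) shows the col_pos values the kernel effectively computes for the caps used in the repro below:

```python
import ctypes

# Emulate `short col_pos = base_col;`: a 16-bit signed truncating assignment.
# The base_col values match the KV caps exercised in the repro script.
for base_col in (8192, 32768, 36864, 49152, 66048):
    col_pos = ctypes.c_int16(base_col).value  # wraps modulo 2^16, sign-extended
    print(f"base_col={base_col:6d} -> col_pos={col_pos:7d}")
# base_col=  8192 -> col_pos=   8192
# base_col= 32768 -> col_pos= -32768
# base_col= 36864 -> col_pos= -28672
# base_col= 49152 -> col_pos= -16384
# base_col= 66048 -> col_pos=    512
```

Note that the first failing cap, 36864, is the first one whose wrapped value (-28672) differs from the true column position; 8192 is unaffected because it fits in a short.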

The non-NAX variant (steel_attention.h line 345) already uses int for col_pos.

This is similar to #2894 / PR #2903 (int32 overflow in mask strides), but in the NAX code path.

To Reproduce

import mlx.core as mx
import numpy as np

D_HEAD = 64

def test(cap, active, n_heads=32, n_kv=16, t=512, dtype=mx.bfloat16):
    np.random.seed(0)
    q = mx.array(np.random.randn(1, n_heads, t, D_HEAD).astype(np.float32)).astype(dtype)
    k = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    v = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    # Additive mask: large negative everywhere except the last `active` positions
    mask = mx.full((1, 1, 1, cap), -1e4).astype(dtype)
    mask[:, :, :, cap - active:] = 0.0
    mx.eval(q, k, v, mask)

    y_fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=D_HEAD**-0.5, mask=mask)
    mx.eval(y_fast)

    # Reference: expand KV heads for GQA, then unfused attention with float32 softmax
    kk = mx.repeat(k, n_heads // n_kv, axis=1) if n_heads != n_kv else k
    vv = mx.repeat(v, n_heads // n_kv, axis=1) if n_heads != n_kv else v
    scores = (q @ kk.transpose(0, 1, 3, 2)) * (D_HEAD**-0.5) + mask
    y_ref = mx.softmax(scores.astype(mx.float32), axis=-1).astype(dtype) @ vv
    mx.eval(y_ref)

    yf = np.array(y_fast.astype(mx.float32))
    yr = np.array(y_ref.astype(mx.float32))
    ratio = np.nanmax(np.abs(yf)) / (np.max(np.abs(yr)) + 1e-10)
    ok = abs(ratio - 1.0) < 0.05
    status = "PASS" if ok else f"FAIL ({ratio:.3f})"
    print(f"  cap={cap:6d}  active={active:5d}  {status}")
    return ok

print(f"MLX {mx.__version__}\n")
n_fail = 0
for cap in [8192, 32768, 36864, 49152, 66048]:
    if not test(cap, 1024): n_fail += 1
print(f"\n{'ALL PASS' if n_fail == 0 else f'{n_fail} FAILED'}")

Expected output:

  cap=  8192  active= 1024  PASS
  cap= 32768  active= 1024  PASS
  cap= 36864  active= 1024  PASS
  cap= 49152  active= 1024  PASS
  cap= 66048  active= 1024  PASS

Actual output:

  cap=  8192  active= 1024  PASS
  cap= 32768  active= 1024  PASS
  cap= 36864  active= 1024  FAIL (0.719)
  cap= 49152  active= 1024  FAIL (0.219)
  cap= 66048  active= 1024  FAIL (0.123)

Expected behavior
The fast-path output should match the unfused reference for all caps (all PASS), independent of KV sequence length.

Desktop:

  • MLX version: 0.31.2 (commit 6a9a121)
  • Apple M5 Max
  • macOS 26.3.1

Additional context

Proposed Fix (PR #3361)
Change short to int on lines 324-325:

-          const short row_pos = base_row + iq * kU;
-          const short col_pos = base_col + ik * kU;
+          const int row_pos = base_row + iq * kU;
+          const int col_pos = base_col + ik * kU;

There are two occurrences of this pattern in the file (mask loading in the aligned and non-aligned branches).
