Skip to content

Fix CPU GatherMM one-row NaNs#3563

Open
TyceHerrman wants to merge 1 commit into
ml-explore:mainfrom
TyceHerrman:fix/test_gather_qmm_sorted-aarch64-darwin
Open

Fix CPU GatherMM one-row NaNs#3563
TyceHerrman wants to merge 1 commit into
ml-explore:mainfrom
TyceHerrman:fix/test_gather_qmm_sorted-aarch64-darwin

Conversation

@TyceHerrman
Copy link
Copy Markdown

@TyceHerrman TyceHerrman commented May 19, 2026

Fix CPU GatherMM one-row NaNs

Fixes #3200.

This fixes an intermittent CPU-only GatherMM failure that can produce NaNs after quantized gather warmup on no-Metal aarch64-darwin builds.

I reproduced the intermittent NaN locally with a CPU-only build and the quantized test loop now passes 30 consecutive runs.

The CPU GatherMM path now handles M == 1 gathered matmuls with a small local one-row helper instead of routing that shape through BLAS. The regular cblas_sgemm path remains unchanged for larger M.

Changes:

  • Add gather_mm_one_row(...) for one-row CPU gathered matmuls.
  • Use that helper only when M == 1 in GatherMM::eval_cpu.
  • Add a focused CPU regression covering the stateful quantized warmup sequence that reproduced the NaN before this fix.

Testing:

CMAKE_ARGS='-DMLX_BUILD_METAL=OFF' PIP_CACHE_DIR=/private/tmp/codex-pip-cache python -m pip install -e . --no-build-isolation -q
PYTHONPATH=python/tests DEVICE=cpu python -m unittest test_blas.TestBlas.test_gather_mm_cpu_m1_quantized_warmup
for i in $(seq 1 30); do PYTHONPATH=python/tests DEVICE=cpu python -m unittest test_quantized.TestQuantized.test_gather_qmm_sorted || exit 1; done
PYTHONPATH=python/tests DEVICE=cpu python -m unittest test_blas test_quantized
git diff --check

AI Assistance Disclosure: Codex assisted with debugging, implementation, and PR draft preparation. The changes were reviewed and tested locally by me before submission.

@booxter

Avoid the BLAS path for one-row CPU GatherMM calls because it can intermittently produce NaNs after quantized gather warmup on no-Metal aarch64-darwin builds. Add a focused regression covering that warmup sequence.

Co-authored-by: Codex <noreply@openai.com>
@TyceHerrman TyceHerrman marked this pull request as ready for review May 19, 2026 03:06
@angeloskath
Copy link
Copy Markdown
Member

I can't reproduce the error. Could you provide a reproduction. I ran the test in the diff of course and it passes without the PR.

@TyceHerrman
Copy link
Copy Markdown
Author

I reproduced the failure locally with the full nixpkgs package build/test path rather than by running the added regression test in isolation.

I tested nixpkgs python313Packages.mlx on aarch64-darwin with this MLX checkout as src, so it uses the nixpkgs package recipe, including CPU-only MLX (MLX_BUILD_METAL=false) and the package pytest suite.

Base upstream/main at 2414e5df6a8cbc0d03f978d7212229f626f1f23b failed. That SHA is the unpatched MLX main commit I used as the base comparison:

SUBFAILED(L=64, K=512, D=544, E=4, I=2, transpose=False, mode='nvfp4') python/tests/test_quantized.py::TestQuantized::test_gather_qmm_sorted
AssertionError: array(nan, dtype=float32) not less than 1.5e-05
===== 1 failed, 671 passed, 40 skipped, 1 deselected in 136.77s (0:02:16) =====

This PR at a01171bd3537b01b475c01d29affce324fd36093 passed the same nixpkgs package build/test:

===== 672 passed, 40 skipped, 1 deselected in 129.63s (0:02:09) =====
nix build --impure -L --keep-failed \
  --argstr rev "$REV" \
  --out-link /private/tmp/mlx-result \
  -f /private/tmp/mlx-local-src-test.nix

with this local expression:

{ rev }:

let
  pkgs = import <nixpkgs> {};
  src = builtins.fetchGit {
    url = "<local-path>";
    inherit rev;
    allRefs = true;
  };
in
  pkgs.python313Packages.mlx.overridePythonAttrs (old: {
    inherit src;
  })

I think added regression test in isolation is not enough - running only that test can pass on unpatched main. The failure appears to be state/order/environment sensitive. The full nixpkgs package test run reproduces the existing test_gather_qmm_sorted NaN on base, and the same package build/test passes with this PR.

@angeloskath
Copy link
Copy Markdown
Member

I understand that but this doesn't mean that the error is in gather mm unfortunately. It could be an out of bounds write in random number generation or something equivalent. I think unless you can reproduce with static inputs we shouldn't merge this PR.

Sorry for that, it's just probably not where the error is.

@TyceHerrman
Copy link
Copy Markdown
Author

The fundamental issue appears to be an Accelerate cblas_sgemm M == 1 edge case

In the failing GatherMM call, BLAS operands:

  • A: selected left-hand row from x
  • B: selected right-hand matrix from w
  • C: output row buffer

MLX trace shows that A and B are finite, A/B/C do not overlap in memory, and B is unchanged by the call.

With temporary local instrumentation in GatherMM, I also ran a preflight CBLAS call immediately before the real one. That preflight used the same GatherMM call site, same selected x row (A), same selected w matrix (B), and same CBLAS arguments, but wrote into a fresh clean scratch output buffer instead of the real MLX output (C). That scratch call passed; the real production call failed only when the selected (C) output lane already contained NaN.

That matches the standalone Accelerate repro. Pure cblas_sgemm with M=1,N=512,K=544, beta=0, misaligned B (mod64 != 0), and C prefilled with NaN returns NaN in tail output columns. With beta=0, BLAS semantics say previous C contents should not affect the result.

So the fix could be to avoid this Accelerate M == 1 SGEMM path in MLX. Could use dedicated one-row dot/GEMV-style path, rather than depending on beta=0 to ignore stale output contents.

if that seems like the correct route to take, happy to work on that here or in new pr

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] test_gather_qmm_sorted consistently fails in some aarch64-darwin environments

2 participants