// Computes a single output row of a float32 matrix product:
//
//   out[n] = sum over k in [0, K) of op(A)[0, k] * op(B)[k, n],  n in [0, N)
//
// Operands are row-major and follow the conventions of the
// cblas_sgemm(CblasRowMajor, ...) call that this helper stands in for when
// M == 1 (alpha = 1, beta = 0):
//   - a_transposed == false: element k of op(A) is a[k]; lda is not read.
//   - a_transposed == true:  element k of op(A) is a[k * lda].
//   - b_transposed == false: element (k, n) of op(B) is b[k * ldb + n].
//   - b_transposed == true:  element (k, n) of op(B) is b[n * ldb + k].
//
// out receives N contiguous floats. Its previous contents are ignored and
// every element is overwritten; with K == 0 all N outputs are set to 0.
// Each output is accumulated in float32 in increasing k order.
inline void gather_mm_one_row(
    const float* a,
    const float* b,
    float* out,
    size_t N,
    size_t K,
    size_t lda,
    size_t ldb,
    bool a_transposed,
    bool b_transposed) {
  if (N == 0) {
    // Nothing to write; also avoids forming row pointers into an empty B.
    return;
  }

  // Distance between consecutive elements of op(A)'s single row.
  const size_t a_step = a_transposed ? lda : 1;

  if (b_transposed) {
    // Column n of op(B) is a contiguous run of K floats in storage, so each
    // output is a plain dot product over sequential memory.
    for (size_t n = 0; n < N; n++) {
      const float* b_row = b + n * ldb;
      float sum = 0.0f;
      for (size_t k = 0; k < K; k++) {
        sum += a[k * a_step] * b_row[k];
      }
      out[n] = sum;
    }
    return;
  }

  // B is stored K x N. Walking a column per output would stride by ldb on
  // every multiply, so iterate k outermost and sweep each stored row of B
  // once. Every out[n] still sees its terms in increasing k order, so the
  // float32 accumulation sequence per element is unchanged.
  for (size_t n = 0; n < N; n++) {
    out[n] = 0.0f;
  }
  for (size_t k = 0; k < K; k++) {
    const float a_val = a[k * a_step];
    const float* b_row = b + k * ldb;
    for (size_t n = 0; n < N; n++) {
      out[n] += a_val * b_row[n];
    }
  }
}
# Regression check for gather_mm on the CPU backend when the left operand has
# a single row (x is built with shape (L, 1, 1, K)). After a fixed sequence of
# gather_mm / gather_qmm warmup evaluations, a final gather_mm must return
# finite values that match the reference `x @ w[indices]`.
def test_gather_mm_cpu_m1_quantized_warmup(self):
    # The check is skipped unless the default device is the CPU.
    if mx.default_device() != mx.cpu:
        self.skipTest("CPU only")

    # Builds one problem instance from a fixed random key (0):
    #   x        float32 activations of shape (L, 1, 1, K)
    #   w        float32 weights after a quantize -> dequantize round trip
    #   indices  uint32 array of shape (L, I) used as rhs_indices
    #   qw, s, b quantized weights, scales and (affine only) biases
    # group_size is returned as well so callers can forward it to gather_qmm.
    def make_case(L, K, D, E, I, transpose, mode):
        # Only the affine mode passes an explicit group size; other modes pass
        # None (presumably selecting the mode's default - confirm against
        # mx.quantize).
        group_size = None if mode != "affine" else 64
        # When not transposing, K and D trade places before any shape is built.
        K, D = (K, D) if transpose else (D, K)
        key = mx.random.key(0)
        k1, k2, k3 = mx.random.split(key, 3)
        # Uniform samples scaled by E and cast to uint32 (assumes the samples
        # lie in [0, 1), which would give values in [0, E) - confirm).
        indices = (mx.random.uniform(shape=(L, I), key=k1) * E).astype(
            mx.uint32
        )
        x = (mx.random.normal((L, 1, 1, K), key=k2) / K**0.5).astype(
            mx.float32
        )
        w = (
            mx.random.normal(
                (E, D, K) if transpose else (E, K, D), key=k3
            )
            / K**0.5
        ).astype(mx.float32)
        if mode == "affine":
            qw, s, b = mx.quantize(w, group_size=group_size, mode=mode)
        else:
            # Only two values are unpacked for non-affine modes; no biases.
            qw, s = mx.quantize(w, mode=mode)
            b = None
        # Replace w with its dequantized form so the float and quantized
        # operands are derived from the same quantized data.
        w = mx.dequantize(qw, s, b, group_size=group_size, mode=mode)
        if transpose:
            # w was created as (E, D, K); swap the last two axes to (E, K, D).
            w = w.swapaxes(-1, -2)
        return x, w, indices, qw, s, b, group_size

    # Evaluates gather_mm and gather_qmm for one case in a single mx.eval
    # call. The results are discarded; only the evaluation itself matters.
    def warmup_gather_qmm(L, K, D, E, I, transpose, mode):
        x, w, indices, qw, s, b, group_size = make_case(
            L, K, D, E, I, transpose, mode
        )
        mx.eval(
            mx.gather_mm(x, w, rhs_indices=indices),
            mx.gather_qmm(
                x,
                qw,
                s,
                b,
                group_size=group_size,
                mode=mode,
                transpose=transpose,
                rhs_indices=indices,
            ),
        )

    # Keep the quantized warmup sequence: the CPU NaN repro is stateful.
    # Each tuple is (L, K, D, E, I, transpose, mode), in make_case order.
    warmups = [
        (32, 512, 512, 4, 2, True, "affine"),
        (32, 512, 544, 4, 2, True, "mxfp4"),
        (32, 512, 544, 4, 2, True, "nvfp4"),
        (32, 512, 544, 4, 2, True, "mxfp8"),
        (133, 512, 512, 4, 2, True, "affine"),
        (133, 512, 555, 4, 2, True, "affine"),
        (133, 512, 512, 4, 2, True, "affine"),
        (64, 512, 512, 4, 2, False, "affine"),
        (64, 512, 544, 4, 2, False, "mxfp4"),
        (64, 512, 544, 4, 2, False, "nvfp4"),
    ]
    for params in warmups:
        warmup_gather_qmm(*params)

    # The checked case is built with make_case only: gather_qmm is not run for
    # it, and the quantized outputs are dropped via `*_`.
    x, w, indices, *_ = make_case(64, 512, 544, 4, 2, False, "mxfp8")
    expected = x @ w[indices]
    actual = mx.gather_mm(x, w, rhs_indices=indices)
    # Finiteness is asserted separately from closeness to the reference.
    self.assertTrue(mx.isfinite(actual).all().item())
    self.assertTrue(mx.allclose(expected, actual, rtol=1e-5, atol=1e-5).item())