Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions mlx/backend/cpu/masked_mm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ inline void segmented_mm(
}
}

// Scalar fallback that computes one output row: out[n] = sum_k A(k) * B(k, n)
// for n in [0, N).
//
//   a / lda / a_transposed: element k of the row is a[k] when not transposed
//                           and a[k * lda] when transposed.
//   b / ldb / b_transposed: element (k, n) is b[k * ldb + n] when not
//                           transposed and b[n * ldb + k] when transposed.
//   out:                    receives N floats; previous contents are
//                           overwritten (K == 0 yields zeros).
inline void gather_mm_one_row(
const float* a,
const float* b,
float* out,
size_t N,
size_t K,
size_t lda,
size_t ldb,
bool a_transposed,
bool b_transposed) {
// The transpose flags are loop-invariant, so turn them into strides once.
const size_t a_k_step = a_transposed ? lda : 1;
const size_t b_k_step = b_transposed ? 1 : ldb;
const size_t b_n_step = b_transposed ? ldb : 1;

for (size_t col = 0; col < N; ++col) {
const size_t b_base = col * b_n_step;
float acc = 0.0f;
// Accumulate in increasing k so the summation order stays fixed.
for (size_t k = 0; k < K; ++k) {
acc += a[k * a_k_step] * b[b_base + k * b_k_step];
}
out[col] = acc;
}
}

} // namespace

void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
Expand Down Expand Up @@ -468,21 +489,31 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];

cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out_ptr + matrix_stride_out * i,
ldc);
auto ai = a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A);
auto bi = b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B);
auto ci = out_ptr + matrix_stride_out * i;

if (M == 1) {
// Avoid intermittent BLAS NaNs for one-row gathered CPU matmuls.
gather_mm_one_row(
ai, bi, ci, N, K, lda, ldb, a_transposed, b_transposed);
} else {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
ai,
lda,
bi,
ldb,
0.0f, // beta
ci,
ldc);
}
}
});
encoder.add_temporaries(std::move(temps));
Expand Down
71 changes: 71 additions & 0 deletions python/tests/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,77 @@ def gather_mm_test(a, b, rhs):
c2 = gather_mm_test(a, b, rhs)
self.assertTrue(mx.allclose(c1, c2, rtol=tol, atol=tol))

def test_gather_mm_cpu_m1_quantized_warmup(self):
    """Check CPU ``gather_mm`` on single-row inputs after quantized warmups.

    Runs a fixed sequence of ``gather_mm`` / ``gather_qmm`` evaluations and
    then asserts that a final ``gather_mm`` result is finite and matches
    ``x @ w[indices]``. Skipped unless the default device is the CPU.
    """
    if mx.default_device() != mx.cpu:
        self.skipTest("CPU only")

    def make_case(L, K, D, E, I, transpose, mode):
        """Build inputs for one case.

        Returns ``(x, w, indices, qw, s, b, group_size)`` where ``x`` has
        shape (L, 1, 1, K), ``indices`` has shape (L, I), and ``w`` is the
        dequantized form of the quantized weights ``qw``.
        """
        # Only the "affine" mode is given an explicit group size here.
        group_size = None if mode != "affine" else 64
        # The roles of K and D are swapped for the non-transposed layout.
        K, D = (K, D) if transpose else (D, K)
        # Fixed key: every call draws from the same seed.
        key = mx.random.key(0)
        k1, k2, k3 = mx.random.split(key, 3)
        # Scale uniform samples by E and cast to uint32 to get the indices
        # (presumably in [0, E) — depends on mx.random.uniform's default range).
        indices = (mx.random.uniform(shape=(L, I), key=k1) * E).astype(
            mx.uint32
        )
        # The two singleton axes give each gathered matmul a single LHS row.
        x = (mx.random.normal((L, 1, 1, K), key=k2) / K**0.5).astype(
            mx.float32
        )
        w = (
            mx.random.normal(
                (E, D, K) if transpose else (E, K, D), key=k3
            )
            / K**0.5
        ).astype(mx.float32)
        # Only the affine branch unpacks a third value (b) from quantize.
        if mode == "affine":
            qw, s, b = mx.quantize(w, group_size=group_size, mode=mode)
        else:
            qw, s = mx.quantize(w, mode=mode)
            b = None
        # Replace w with its quantize -> dequantize round trip so the dense
        # weights correspond to the quantized ones.
        w = mx.dequantize(qw, s, b, group_size=group_size, mode=mode)
        if transpose:
            # (E, D, K) -> (E, K, D) so that x @ w is well formed.
            w = w.swapaxes(-1, -2)
        return x, w, indices, qw, s, b, group_size

    def warmup_gather_qmm(L, K, D, E, I, transpose, mode):
        """Evaluate the dense and quantized gather matmuls for one case."""
        x, w, indices, qw, s, b, group_size = make_case(
            L, K, D, E, I, transpose, mode
        )
        # The results are discarded; only the evaluation itself matters.
        mx.eval(
            mx.gather_mm(x, w, rhs_indices=indices),
            mx.gather_qmm(
                x,
                qw,
                s,
                b,
                group_size=group_size,
                mode=mode,
                transpose=transpose,
                rhs_indices=indices,
            ),
        )

    # Keep the quantized warmup sequence: the CPU NaN repro is stateful.
    # Each tuple is (L, K, D, E, I, transpose, mode).
    warmups = [
        (32, 512, 512, 4, 2, True, "affine"),
        (32, 512, 544, 4, 2, True, "mxfp4"),
        (32, 512, 544, 4, 2, True, "nvfp4"),
        (32, 512, 544, 4, 2, True, "mxfp8"),
        (133, 512, 512, 4, 2, True, "affine"),
        (133, 512, 555, 4, 2, True, "affine"),
        (133, 512, 512, 4, 2, True, "affine"),
        (64, 512, 512, 4, 2, False, "affine"),
        (64, 512, 544, 4, 2, False, "mxfp4"),
        (64, 512, 544, 4, 2, False, "nvfp4"),
    ]
    for params in warmups:
        warmup_gather_qmm(*params)

    # Final check: only the dense gather_mm is compared with the reference.
    x, w, indices, *_ = make_case(64, 512, 544, 4, 2, False, "mxfp8")
    expected = x @ w[indices]
    actual = mx.gather_mm(x, w, rhs_indices=indices)
    self.assertTrue(mx.isfinite(actual).all().item())
    self.assertTrue(mx.allclose(expected, actual, rtol=1e-5, atol=1e-5).item())

def test_gather_mm_sorted_vjp(self):
def gather_mm_ref(a, b, rhs):
b = b[rhs]
Expand Down