[CUDA] Add GatherQMM for quantized gather matmul #3321

Open

Lyxot wants to merge 5 commits into ml-explore:main from Lyxot:cuda/gather_qmm

Conversation

Contributor

@Lyxot Lyxot commented Mar 25, 2026

Proposed Changes

Adds a CUDA GatherQMM implementation, enabling quantized MoE inference on the GPU. Also extracts a qmv_impl device function from qmv_kernel, so that qmv_kernel and the new gather_qmv_kernel share the same dequantize+FMA+reduce compute body.

What's supported

  • transpose=true
  • All quantization modes: affine (2/3/4/5/6/8-bit), mxfp4, mxfp8, nvfp4
  • All float types: float32, float16, bfloat16
  • Arbitrary batch dimensions and multi-dimensional gather indices
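As a reference for the semantics, a gathered matmul selects one expert weight matrix per batch element and multiplies it with that element's activations. A minimal NumPy sketch with dense weights (dequantization elided; the shapes and transpose=true layout here are assumptions based on the description above, not MLX's exact API):

```python
import numpy as np

# Reference semantics of a gathered expert matmul. Weights are shown
# dense for clarity; the real CUDA kernel dequantizes quantized blocks
# on the fly inside qmv_impl.
def gather_mm_ref(x, w, rhs_indices):
    # x: (batch, M, K) activations
    # w: (experts, N, K) per-expert weights (transpose=true layout)
    # rhs_indices: (batch,) expert index chosen for each batch element
    out = np.empty((x.shape[0], x.shape[1], w.shape[1]), dtype=x.dtype)
    for b, e in enumerate(rhs_indices):
        out[b] = x[b] @ w[e].T  # (M, K) @ (K, N)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 1, 8), dtype=np.float32)   # M=1, K=8
w = rng.standard_normal((4, 16, 8), dtype=np.float32)  # 4 experts, N=16
idx = np.array([3, 0])
y = gather_mm_ref(x, w, idx)
```

The fused kernel performs the same per-element expert selection in a single launch instead of one matmul call per batch element.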

What's next

  • Dedicated gather_fp_qmv to inherit architecture-specific FP-mode optimizations from fp_qmv
  • Matrix-matrix path for large M using CUTLASS-based gather_qmm kernel
  • Sorted-rhs batching with grouped GEMM optimization for sorted expert indices

Performance

4-bit affine quantization, M=1:

N      K      Experts  Batch  Loop (µs)  GatherQMM (µs)  Speedup
4096   4096   8        2      233        91              2.6x
4096   4096   8        8      894        252             3.5x
4096   4096   8        32     3559       1000            3.6x
11008  4096   8        32     8012       2642            3.0x

Loop baseline: individual quantized_matmul calls per expert, stacked.
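The benchmark uses 4-bit affine quantization, where each weight is reconstructed per group as w = scale * q + bias. A minimal NumPy sketch of the dequantization step (the nibble packing order and group layout here are illustrative assumptions, not MLX's exact storage format):

```python
import numpy as np

# Illustrative 4-bit affine dequantization with per-group scale/bias.
# Packing convention assumed here: low nibble first within each byte.
def dequant_affine_4bit(q_packed, scales, biases, group_size=4):
    # q_packed: (rows, K/2) uint8, two 4-bit codes per byte
    # scales, biases: (rows, K/group_size) float32, one pair per group
    lo = q_packed & 0x0F
    hi = q_packed >> 4
    q = np.stack([lo, hi], axis=-1).reshape(q_packed.shape[0], -1)
    q = q.astype(np.float32)
    groups = q.reshape(q.shape[0], -1, group_size)
    # affine reconstruction: w = scale * q + bias, applied per group
    return (groups * scales[..., None] + biases[..., None]).reshape(q.shape)
```

In the real kernel this reconstruction happens inside the shared qmv_impl body, fused with the FMA and reduction rather than materializing dense weights.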

Copilot AI review requested due to automatic review settings March 25, 2026 17:06

Copilot AI left a comment


Pull request overview

Adds a CUDA implementation path for GatherQMM (quantized gather + matmul), enabling GPU execution for MoE-style gathered expert matmuls. It also refactors the existing qmv CUDA kernel to share a common dequantize+FMA+reduce compute body with the new gather kernel.

Changes:

  • Add GatherQMM::eval_gpu CUDA dispatch that routes to a new gather_qmv kernel.
  • Extract shared device compute body (qmv_impl) from qmv_kernel and reuse it in gather_qmv_kernel.
  • Expose gather_qmv in the CUDA quantized qmm header and enable GPU support for GatherQMM.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

  • mlx/backend/cuda/quantized/quantized.cpp — Adds the CUDA GatherQMM::eval_gpu implementation that allocates the output and launches gather_qmv.
  • mlx/backend/cuda/quantized/qmm/qmv.cu — Refactors the qmv compute body into qmv_impl and adds the new gather_qmv_kernel plus its host launcher.
  • mlx/backend/cuda/quantized/qmm/qmm.h — Declares the new gather_qmv entry point.
  • mlx/backend/cuda/primitives.cpp — Removes NO_GPU(GatherQMM) to allow CUDA backend execution.


Lyxot added 2 commits March 26, 2026 10:46
Broadcast index arrays (e.g., from default lhs_indices) may have
stride-0 dimensions backed by fewer elements than the logical shape.
The gather_qmv kernel reads indices linearly after collapse, causing
out-of-bounds access. Fix by copying to contiguous layout first.
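The stride-0 hazard this commit fixes can be reproduced in NumPy terms (illustrative only; the actual fix is in the CUDA launcher, not Python):

```python
import numpy as np

# Broadcasting creates stride-0 views whose logical shape is backed by
# far fewer physical elements. A kernel that walks such a buffer
# linearly (element_index * assumed_stride) would read past the real
# storage, which is the out-of-bounds access described above.
base = np.zeros(1, dtype=np.int32)      # one physical element (4 bytes)
idx = np.broadcast_to(base, (32,))      # logical shape (32,), stride 0
assert base.nbytes == 4                 # still only 4 bytes of storage

# Fix mirrored here: materialize a contiguous copy before linear reads.
idx_contig = np.ascontiguousarray(idx)  # 32 real elements, stride 4
```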
Collaborator

@zcbenz zcbenz left a comment


Looks good to me, thanks!


3 participants