[CUDA] New 4bit GEMM kernels for inference #1949
matthewdouglas wants to merge 7 commits
TL;DR
4-bit inference is up to 4x faster at typical serving batch sizes (i.e. 2-64) across GPUs from Turing through Blackwell. Gains are largest at smaller batches, and in many cases the new kernel is also faster at a batch size of 1. We also see improvements when using nested quantization, where the fusion in the new kernels provides an advantage.
Summary
This PR adds new custom fused 4-bit dequantize+GEMM kernels. It is intended to replace both the existing batch-size-1 path (GEMV kernel) and the `dequantize + F.linear` path for small-to-medium batch sizes. The new kernels additionally fuse the nested absmax dequantization and optional bias addition.
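For reference, the unfused baseline that the fused kernels replace looks roughly like the sketch below; the shapes, dtypes, and nested-quantization flag are illustrative and not taken from the benchmark configs.

```python
# Minimal sketch of the unfused reference path (assumes a CUDA device).
import torch
import torch.nn.functional as F
import bitsandbytes.functional as bnbF

out_features, in_features, batch = 4096, 4096, 8
weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
bias = torch.randn(out_features, dtype=torch.float16, device="cuda")
x = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")

# NF4 quantization with nested (double) quantization of the absmax values.
packed, quant_state = bnbF.quantize_4bit(
    weight, quant_type="nf4", compress_statistics=True
)

# Unfused path: dequantize the full weight, then run a dense GEMM.
# The new kernels fuse these steps (including the nested absmax
# dequantization and the bias addition) into a single launch.
w_dq = bnbF.dequantize_4bit(packed, quant_state)
out = F.linear(x, w_dq, bias)
```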
New kernels
There are three new kernel tiers (a SIMT kernel and two tensor-core MMA kernels), selected automatically at runtime via heuristics calibrated with benchmarks on T4, A10, A100, RTX 4090, L4, L40S, H100, H200, B200, and RTX PRO 6000. The heuristics are expected to generalize well across most hardware on sm75+. Note that while the SIMT kernel does compile and should work on Pascal and Volta, we did not run benchmarks there and will dispatch to it more conservatively.
Claude Code was used to create a simulation tool to help calibrate these heuristics. We lean towards dispatching to the new kernels conservatively to avoid performance regressions, so across many shapes we may miss some wins by falling back earlier than strictly necessary. Across all of the GPU testing, covering over 40 shapes and batch sizes from 1 through 2048, we make a regressive decision less than 1% of the time and make the ideal decision over 80% of the time. For the remaining decisions, we simply leave some performance on the table by either falling back too early or selecting a kernel when a faster alternative was available.
The MMA kernels use the `m16n8k8` and `m16n8k16` tensor core shapes. The dispatch system considers the problem dimensions, GPU architecture, and SM count when deciding which kernel to launch, and within each of the MMA kernels there are several tiling configurations for the dispatcher to choose from. For larger batch sizes, we automatically fall back to the existing `dequantize + F.linear` path.
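As a rough illustration of the shape of that decision, here is a sketch of a dispatcher; the kernel names reflect this PR, but the thresholds and tile-selection criterion below are hypothetical and are not the calibrated heuristics.

```python
# Illustrative-only sketch of the kind of decision the dispatcher makes.
# Thresholds and tile choices are hypothetical placeholders.
def select_4bit_kernel(m: int, n: int, k: int, sm_version: int, sm_count: int) -> str:
    if sm_version < 75:
        # Pascal/Volta: only the SIMT kernel is used, and only conservatively.
        return "simt" if m <= 2 else "dequant_linear_fallback"
    if m == 1:
        return "simt"  # single-token decode is memory-bound
    if m <= 64:
        # Tensor-core MMA path; prefer the larger MMA shape when the grid
        # is big enough to keep all SMs busy (hypothetical criterion).
        blocks = (n // 128) * (k // 64)
        if sm_version >= 80 and blocks >= sm_count:
            return "mma_m16n8k16"
        return "mma_m16n8k8"
    # Larger batches: the existing dequantize + F.linear path is faster.
    return "dequant_linear_fallback"
```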
New custom op layer
Added a new custom operator `bitsandbytes::gemm_4bit` to abstract these new kernels. It is backwards compatible with other backends, i.e. they will continue to launch their existing GEMV implementations. Note that this operator is not intended as a public API and can change in the future. It serves as an extension point to add implementations for additional hardware.
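Because the op is internal and its schema may change, a safer way to explore it is to inspect the registered schema at runtime rather than hard-code a signature. A small sketch, assuming a bitsandbytes build that includes this PR:

```python
# Check for the new op and print its schema(s) instead of assuming arguments.
import torch
import bitsandbytes  # registers the bitsandbytes::* custom ops

if hasattr(torch.ops.bitsandbytes, "gemm_4bit"):
    packet = torch.ops.bitsandbytes.gemm_4bit
    for overload in packet.overloads():
        # The schema shows the actual argument list for each overload.
        print(getattr(packet, overload)._schema)
else:
    print("gemm_4bit op not present in this bitsandbytes build")
```

Backends that want to plug into this extension point would register their own kernels against the `bitsandbytes::gemm_4bit` op name (e.g. via torch.library), matching whatever schema the op exposes.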
Additional cleanup
`bitsandbytes.matmul_4bit` now normalizes packed weights to a canonical shape internally, so callers no longer need to pass the quantized weight tensor in any particular orientation. For weights quantized in the `[out_features, in_features]` orientation, passing a `.t()` of the quantized weight tensor is fine, although no longer required, and will produce correct results. However, weights that were quantized in the transposed `[in_features, out_features]` orientation will now emit a `DeprecationWarning`; support for this is likely to be dropped in the future, as it is not an expected or typical use case.
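A minimal sketch of the relaxed orientation handling, assuming a CUDA device and a build that includes this PR (shapes are illustrative):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as bnbF

out_features, in_features = 4096, 4096
weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
x = torch.randn(8, in_features, dtype=torch.float16, device="cuda")

# Quantize in the standard [out_features, in_features] orientation.
packed, quant_state = bnbF.quantize_4bit(weight, quant_type="nf4")

# Previously callers passed the transposed packed tensor; this still works...
y_t = bnb.matmul_4bit(x, packed.t(), quant_state)

# ...but is no longer required: the packed weight is normalized internally.
y = bnb.matmul_4bit(x, packed, quant_state)

torch.testing.assert_close(y, y_t)
```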
End-to-End Benchmark Results
All runs: NF4, input_len=128, output_len=128, bitsandbytes 0.49.2 (stable) vs this PR (new).
PyTorch: 2.10.0+cu130 in eager mode. Transformers: 5.7.0.
All hardware except RTX 4090 was hosted on Modal.com.
TPOT = time per output token (decode step), lower is better.
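For reference, TPOT is typically derived as in the sketch below; the exact measurement harness used for these benchmarks may differ.

```python
# Usual definition of TPOT: total time after the first token, divided by
# the remaining generated tokens (a sketch, not the benchmark harness).
def tpot_ms(e2e_latency_s: float, ttft_s: float, output_len: int) -> float:
    return (e2e_latency_s - ttft_s) / (output_len - 1) * 1000.0

# Example: 3.2 s end-to-end, 0.25 s to first token, 128 output tokens
# -> roughly 23.2 ms per output token.
print(f"{tpot_ms(3.2, 0.25, 128):.1f} ms")
```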