diff --git a/CMakeLists.txt b/CMakeLists.txt index ea61fbdf3..d9c691e79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,6 +234,29 @@ if(BUILD_CUDA) list(APPEND SRC_FILES ${GPU_FILES}) + # 4-bit GEMM SIMT and dispatch compile for all selected archs. + # sm75/sm80+ MMA files only compile if those archs are selected. + set(_cc_all ${COMPUTE_CAPABILITY} ${_LATEST_CAPABILITY}) + list(APPEND SRC_FILES csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu) + + if(75 IN_LIST _cc_all) + # Builds only on sm75 + list(APPEND SRC_FILES csrc/gemm_4bit_sm75.cu) + add_compile_definitions(BNB_HAS_GEMM4BIT_SM75) + endif() + + set(_cc_sm80plus) + foreach(_cc IN LISTS _cc_all) + if(_cc GREATER_EQUAL 80) + list(APPEND _cc_sm80plus ${_cc}) + endif() + endforeach() + if(_cc_sm80plus) + # Builds only on sm80+ + list(APPEND SRC_FILES csrc/gemm_4bit_sm80.cu) + add_compile_definitions(BNB_HAS_GEMM4BIT_SM80) + endif() + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) elseif(BUILD_HIP) diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index fcfd0a03f..9b97480aa 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -220,6 +220,65 @@ def _( return out, absmax +torch.library.define( + "bitsandbytes::gemm_4bit", + "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, int blocksize, str quant_type, " + "Tensor? bias=None, Tensor? absmax_8bit=None, Tensor? absmax_code=None, Tensor? absmax_offset=None) -> Tensor", +) + + +@register_fake("bitsandbytes::gemm_4bit") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(len(shapeB) == 2, lambda: f"shapeB must be 2D [N, K], got {list(shapeB)}") + torch._check(A.shape[-1] == shapeB[1], lambda: f"A inner dim ({A.shape[-1]}) must match shapeB ({shapeB[1]})") + torch._check( + A.dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32), + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(blocksize in [32, 64, 128, 256, 512, 1024, 2048, 4096], lambda: f"invalid blocksize {blocksize}") + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + if absmax_8bit is not None: + torch._check(absmax_8bit.ndim == 1, lambda: f"absmax_8bit must be 1D, got {absmax_8bit.ndim}D") + torch._check(absmax_8bit.dtype == torch.uint8, lambda: f"absmax_8bit must be uint8, got {absmax_8bit.dtype}") + torch._check(absmax_code is not None, lambda: "absmax_code required when absmax_8bit is provided") + torch._check(absmax_code.ndim == 1, lambda: f"absmax_code must be 1D, got {absmax_code.ndim}D") + torch._check( + absmax_code.shape[0] == 256, lambda: f"absmax_code must have 256 entries, got {absmax_code.shape[0]}" + ) + torch._check( + absmax_code.dtype == torch.float32, lambda: f"absmax_code must be float32, got {absmax_code.dtype}" + ) + torch._check(absmax_offset is not None, lambda: "absmax_offset required when absmax_8bit is provided") + torch._check( + absmax_offset.ndim == 0, 
lambda: f"absmax_offset must be a scalar (0-dim), got {absmax_offset.ndim}D" + ) + torch._check( + absmax_offset.dtype == torch.float32, lambda: f"absmax_offset must be float32, got {absmax_offset.dtype}" + ) + if bias is not None: + torch._check(bias.ndim == 1, lambda: f"bias must be 1D, got {bias.ndim}D") + torch._check(bias.shape[0] == shapeB[0], lambda: f"bias length ({bias.shape[0]}) must match N ({shapeB[0]})") + torch._check(bias.dtype == A.dtype, lambda: f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})") + N = shapeB[0] + return torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device) + + torch.library.define( "bitsandbytes::dequantize_blockwise", "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor", diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 95a7d9090..e254f63df 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -315,9 +315,37 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] else: return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - # 1. Dequantize - # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) + # Normalize to canonical [(N*K+1)//2, 1]. Packed weights are always contiguous + # in this orientation (B.t() callers get strides [1,1], still compatible). + # quant_state.shape is the source of truth for N and K. + B = B.view(-1, 1) + + if not quant_state.nested: + output = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B, + quant_state.shape, + quant_state.absmax, + quant_state.blocksize, + quant_state.quant_type, + bias=bias, + ) + elif quant_state.state2.blocksize == 256: + output = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B, + quant_state.shape, + quant_state.state2.absmax, + quant_state.blocksize, + quant_state.quant_type, + bias=bias, + absmax_8bit=quant_state.absmax, + absmax_code=quant_state.state2.code, + absmax_offset=quant_state.offset, + ) + else: + raise NotImplementedError("nested quantization with state2.blocksize != 256 is not supported") + if out is not None: out.copy_(output) output = out @@ -351,7 +379,9 @@ def backward(ctx, grad_output): # not supported by PyTorch. TODO: create work-around # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) if req_gradA: - grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + # B in ctx.tensors is already in canonical [(N*K+1)//2, 1] form (normalized in forward). + # dequantize returns [N, K]; matmul(grad_output[M,N], [N,K]) = grad_A[M,K]. 
+ grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype)) return grad_A, grad_B, None, grad_bias, None @@ -381,26 +411,81 @@ def matmul_4bit( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ): - assert quant_state is not None - if A.device.type == "cpu": - if getattr(quant_state, "packing_format_for_cpu", False): - out = F.gemv_4bit(A, B, out, state=quant_state) - if bias is not None: - out += bias - return out - else: - return MatMul4Bit.apply(A, B, out, bias, quant_state) - - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": - if A.shape[-1] % quant_state.blocksize != 0: + if quant_state is None: + raise ValueError("quant_state is required") + if len(quant_state.shape) != 2: + raise ValueError("matmul_4bit: quant_state.shape must be 2D [N, K]") + + # packing_format_for_cpu uses a different memory layout optimized for AVX512BF16. + # This flag is only set for inference (weight conversion happens at eval time). + # The underlying kernel supports any M via tiled GEMM despite the gemv name. + if A.device.type == "cpu" and getattr(quant_state, "packing_format_for_cpu", False): + result = F.gemv_4bit(A, B, out=out, state=quant_state) + if bias is not None: + result += bias + return result + + # Normalize B to canonical [(N*K+1)//2, 1]. Packed weights are always contiguous + # in this orientation (B.t() callers get strides [1,1], still compatible). + # quant_state.shape is the source of truth for N and K. + B = B.view(-1, 1) + + K = A.shape[-1] + + # Weight is in [K, N] orientation when A's inner dim matches shape[0] not shape[1]. + # Square weights (K==N) are ambiguous and treated as [N, K]. + if K == quant_state.shape[0] and K != quant_state.shape[1]: + if not _is_compiling(): warn( - f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", + f"matmul_4bit: weight was quantized from a [K, N] tensor (quant_state.shape={list(quant_state.shape)}). " + "Re-quantize from the weight in [N, K] (out_features, in_features) orientation. 
" + "This will be an error in a future version.", + DeprecationWarning, + stacklevel=2, + ) + B_dq = F.dequantize_4bit(B, quant_state).to(A.dtype) + result = torch.nn.functional.linear(A, B_dq.t(), bias) + if out is not None: + out.copy_(result) + return out + return result + + needs_grad = torch.is_grad_enabled() and (A.requires_grad or (bias is not None and bias.requires_grad)) + if not needs_grad: + A_numel = A.numel() + if A_numel == 0: + if out is not None: + return out + return torch.empty((*A.shape[:-1], quant_state.shape[0]), dtype=A.dtype, device=A.device) + + if not quant_state.nested: + result = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B, + quant_state.shape, + quant_state.absmax, + quant_state.blocksize, + quant_state.quant_type, + bias=bias, + ) + elif quant_state.state2.blocksize == 256: + result = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B, + quant_state.shape, + quant_state.state2.absmax, + quant_state.blocksize, + quant_state.quant_type, + bias=bias, + absmax_8bit=quant_state.absmax, + absmax_code=quant_state.state2.code, + absmax_offset=quant_state.offset, ) - return MatMul4Bit.apply(A, B, out, bias, quant_state) else: - out = F.gemv_4bit(A, B.t(), out, state=quant_state) - if bias is not None: - out += bias + raise NotImplementedError("nested quantization with state2.blocksize != 256 is not supported") + if out is not None: + out.copy_(result) return out - else: - return MatMul4Bit.apply(A, B, out, bias, quant_state) + return result + + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 409e0252d..835febd99 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,7 +1,9 @@ from collections.abc import Sequence import ctypes as ct +import functools from math import prod from typing import Optional +from warnings import warn import torch @@ -9,6 +11,14 @@ from ..._ops import register_kernel from ...cextension import lib +from ..default.ops import _gemm_4bit_default_impl +from ..utils import _get_4bit_code + + +@functools.cache +def _gpu_dispatch_props(device_index): + props = torch.cuda.get_device_properties(device_index) + return props.multi_processor_count, props.major, props.minor @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -543,6 +553,364 @@ def _gemv_4bit_impl( ) +@functools.cache +def _gemm_4bit_use_custom(device_index, dtype, M, N, K): + """Custom kernel vs dequant+F.linear heuristic for M in [5, 1536]. + + Per-arch notes (bf16/fp16, M >= 8, large weight): + sm75 (T4, ~300 GB/s GDDR6): fp16 MMA only; GDDR makes dequant expensive. + sm80 (A100, ~2 TB/s HBM2e): mma.sync; HBM thresholds; K-heavy shapes handled explicitly. + sm86 (A10, ~600 GB/s GDDR6): dedicated block; wider M caps than sm89 at medium N. + sm89 (4090, L40S, GDDR6X): default fallback; tall-K and large-N get higher M caps. + sm90 (H100/H200, HBM3/HBM3e): dequant+linear is much faster; thresholds are tight. + sm100 (B200/B300, HBM3e): exits early at top of function. + sm120 (RTX 5000, GDDR7): dedicated block; medium-N tiers differ from sm89. + """ + num_sms, major, minor = _gpu_dispatch_props(device_index) + n_blocks = (N + 63) // 64 + + # fp32 has no MMA kernel; pre-sm75 has no MMA kernel; sm75 has fp16 MMA only. + # For all of these, custom only wins in the SIMT range (M<8). 
+ if dtype == torch.float32 or major < 7: + return M < 8 + if major == 7 and (minor < 5 or dtype != torch.float16): + return M < 8 + + # sm87 and sm110: no calibration data, conservative fallback. + if (major == 8 and minor == 7) or major == 11: + return False + + # sm100 (B200/B300): dequant+F.linear is significantly faster than our mma.sync kernel. + if major == 10: + if n_blocks >= num_sms * 3: + return M <= 32 + if n_blocks >= num_sms: + return False if K >= N else M <= 8 + return False + + is_sm75 = major == 7 and minor == 5 + is_sm80 = major == 8 and minor == 0 + is_sm86 = major == 8 and minor == 6 + is_sm90 = major == 9 + is_sm120 = major == 12 and minor == 0 + is_hbm = is_sm80 or is_sm90 # sm100 already returned above + tall_k_2xn = K > N * 2 + + # Small-weight path (N*K < 4MB): dequant overhead dominates. + if N * K < 4 * 1024 * 1024: + if K * 2 < N: + # Very short K (K < N/2): latency-dominated, custom 3-9x cheaper. + if is_hbm: + # Calibrated on A100: custom wins to M=1536 (low wave), M=512 (high wave). + # Calibrated on H100/H200: custom wins to M=512 (low wave), M=320 (high wave). + low_wave = n_blocks * 3 < num_sms + if is_sm80: + return M <= (1536 if low_wave else 512) + return M <= (512 if low_wave else 320) + if is_sm75: + # T4: wins require >=3 waves; M cap scales with K depth. + if n_blocks >= num_sms * 3: + return M <= 320 + if K >= 1024: + return M <= 64 + if K >= 704: + return M <= 96 + return M <= 320 + # sm86/sm89/sm120: well-subscribed wins to M=320; undersubscribed tighter. + if n_blocks >= num_sms: + return M <= 320 + return M <= 192 if n_blocks * K > num_sms * 320 else M <= 320 + # K*2 >= N: arch-specific handling at low occupancy. + quarter_wave = n_blocks * 4 <= num_sms + if is_sm80 and quarter_wave: + # A100 <1/4 wave: K>=N loses earlier (K-tiling efficient on HBM2e). + if K >= N: + return M <= (32 if n_blocks * 8 <= num_sms else 128) + return M <= 384 + # T4 <1 wave non-short-K: M>8 routes through occupancy caps below. + if is_sm75 and n_blocks < num_sms and M > 8: + return M <= 64 + # General tiers (sm90, sm86, sm89, sm120): + # GDDR tall-K (K>=N) at <1/4 wave: K-tiling in default impl wins above M=23. + if quarter_wave: + return M <= (32 if (K < N or is_hbm) else 23) + if n_blocks * 2 <= num_sms: + return M <= 16 + return False # >=1/2 wave: no validated wins for remaining small-weight shapes + + # Non-small-weight: custom wins up to M=512; dequant+F.linear wins above that. + if M > 512: + return False + + # M=5-7: custom SIMT generally wins because dequant cost dominates. + # Exceptions where K-tiling efficiency or MMA occupancy favors dequant+F.linear: + # HBM at M=6-7: tall-K (K>N) at ~3/4 MMA wave. + # sm90 square (K==N) at specific occupancy bands: arch-specific crossover. + if M < 8: + hbm_m67_thresh = 36 if is_sm90 else 48 + if is_hbm and M >= 6 and n_blocks >= hbm_m67_thresh: + lt_75pct_wave = n_blocks * 4 < num_sms * 3 + lt_60pct_wave = n_blocks * 5 < num_sms * 3 + # Tall-K: K-tiling in default impl wins when under-subscribed. + if K > N and lt_75pct_wave: + return False + # Square: arch-specific crossover around 0.6 wave. + # A100 (HBM2e): loses below 0.6 wave. H100/H200 (HBM3/3e): loses above. + if K == N: + if is_sm80 and lt_60pct_wave: + return False + if is_sm90 and lt_75pct_wave and not lt_60pct_wave: + return False + return True + + # M in [8, 512]: per-arch tier ladders. + + if is_sm75: + # fp16 MMA (m16n8k8). GDDR bandwidth makes dequant relatively expensive. 
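+        # e.g. a 40-SM T4 at N=14336 (n_blocks=224, >= 3 waves) keeps the custom
+        # kernel up to M=128 when K < N (hypothetical shape, for illustration only).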
+        if n_blocks >= num_sms * 3:
+            return M <= (128 if K < N else 64)
+        if n_blocks >= num_sms // 2:
+            return M <= 64
+        return M <= 32
+
+    if is_sm80:
+        # mma.sync (m16n8k16). HBM2e thresholds; K-heavy shapes handled explicitly.
+        if n_blocks >= num_sms * 3:
+            return M <= 128
+        if n_blocks >= num_sms:
+            return M <= (64 if K < N else 32)
+        # Very tall-K (K>=3N) at >1/4 wave: K-tiling in default impl wins at all M.
+        # Uses >= to catch K==3N (e.g. N=4096,K=12288 M=9-16: measured regression on A100).
+        if K >= N * 3 and n_blocks * 4 > num_sms:
+            return False
+        # Square (K==N) at 0.5-1 wave: K-tiling wins at ~0.6 wave.
+        # n_blocks>=48 excludes small N where SIMT still wins.
+        if K == N and n_blocks >= 48 and n_blocks * 5 < num_sms * 3:
+            return False
+        # <0.5 wave: K<=N custom wins to M=128; K>N default wins above wave threshold.
+        if n_blocks * 2 < num_sms:
+            if K <= N:
+                return M <= 128
+            if n_blocks * 3 >= num_sms:
+                return False
+        # 0.5-1 wave K<N: custom wins to M=128.
+        if n_blocks >= num_sms // 2 and K < N:
+            return M <= 128
+        return M <= 16
+
+    if is_sm86:
+        # ~600-940 GB/s GDDR6/GDDR6X. Dedicated block: sm89 fallback tiers are too
+        # loose for 600 GB/s bandwidth and cause regressions at medium N (~N=4096).
+        if n_blocks >= num_sms:
+            return M <= 128
+        if n_blocks >= num_sms // 2:
+            return M <= 64
+        return M <= 16
+
+    if is_sm90:
+        # HBM3/HBM3e. dequant+F.linear (WGMMA path) is significantly faster than our
+        # mma.sync kernel; thresholds are calibrated conservatively (H100/H200 share path).
+        if n_blocks >= num_sms * 3:
+            return M <= 64
+        if n_blocks >= num_sms * 2:
+            return M <= 48
+        if n_blocks >= num_sms:
+            return M <= 32
+        if n_blocks >= num_sms // 2:
+            # Square/tall-K at <3/4 wave: K-tiling too efficient on HBM3e.
+            if K >= N and n_blocks * 4 < num_sms * 3:
+                return False
+            return M <= 16
+        return False
+
+    if is_sm120:
+        # GDDR7 (~1-1.8 TB/s). Medium-N threshold tiers differ from sm89.
+        # sm121 (DGX Spark) has a different bandwidth/SM profile; uses sm89
+        # fallback below until validated.
+        if n_blocks >= num_sms * 3:
+            return M <= 256
+        if n_blocks >= num_sms * 2:
+            return M <= 128
+        # Short-K (K<N) caps lower than K>=N in this band.
+        if n_blocks * 3 >= num_sms * 4:
+            return M <= (96 if K >= N else 64)
+        if n_blocks >= num_sms:
+            return M <= 64
+        if n_blocks >= num_sms // 2:
+            # Large-N (n_blocks>=128, N>=8192) with K>=N/2: calibrated on RTX Pro 6000 to M=64.
+            return M <= (64 if (K * 2 >= N and n_blocks >= 128) else 8)
+        if tall_k_2xn and n_blocks > 64:
+            return M <= 16
+        return M <= 8
+
+    # Fallback: sm89 (4090, L40S, L4), sm121 (DGX Spark), unrecognized arches.
+    # GDDR bandwidth makes dequant relatively expensive so custom wins at higher M.
+    if n_blocks >= num_sms * 3:
+        return M <= 256
+    if n_blocks >= num_sms * 2:
+        return M <= 128
+    # Near-wave (~0.8x): tall-K and very large N (n_blocks>=200, N>=14336) raise cap to M=128.
+    # N=10240 (n_blocks=160) deliberately excluded to avoid regressions there.
+    if n_blocks * 5 >= num_sms * 4:
+        if tall_k_2xn or n_blocks >= 200:
+            return M <= 128
+        # Square/tall-K: >=60 SMs wins to M=128; <60 SMs default wins earlier.
+        if K >= N:
+            return M <= (128 if num_sms >= 60 else 32)
+        return M <= 64
+    if n_blocks >= num_sms // 2:
+        if tall_k_2xn:
+            return M <= 64
+        if n_blocks >= 64:
+            return M <= 8
+        return M <= 32
+    # Tall-K (K>N) at narrow N (n_blocks<=48): M-driven crossover.
+    # K>=3N (e.g. N=2560,K=10240): SIMT wins to M=12. Moderate K>N: M=10.
+ if K > N and n_blocks <= 48: + return M <= (12 if K >= N * 3 else 10) + return M <= (16 if (tall_k_2xn or n_blocks < 48) else 8) + + +if torch.version.hip is None: + + @register_kernel("bitsandbytes::gemm_4bit", "cuda") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + K = A.shape[-1] + M = A.numel() // K + N = shapeB[0] + + # M>1536: dequant+F.linear wins (dequant savings negligible at very large batch). + # M<=4: always custom (custom kernel wins universally at small batch). + # M in [5, 1536]: shape/arch-dependent; cached per (device, dtype, M, N, K). + if M > 1536: + use_custom = False + elif K % blocksize != 0: + warn( + f"inner dimension ({K}) is not aligned for fast kernel " + f"with blocksize={blocksize}, falling back to slower implementation.", + UserWarning, + ) + use_custom = False + else: + use_custom = M <= 4 or _gemm_4bit_use_custom(A.device.index, A.dtype, M, N, K) + + if not use_custom: + return _gemm_4bit_default_impl( + A, B, shapeB, absmax, blocksize, quant_type, bias, absmax_8bit, absmax_code, absmax_offset + ) + + if K != shapeB[1]: + raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})") + if absmax.dtype != torch.float32: + raise RuntimeError(f"absmax must be float32, got {absmax.dtype}") + if bias is not None: + if bias.ndim != 1: + raise RuntimeError(f"bias must be 1D, got {bias.ndim}D") + if bias.dtype != A.dtype: + raise RuntimeError(f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})") + + quant_type_int = 1 if quant_type == "fp4" else 2 + + out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device) + stream = torch._C._cuda_getCurrentRawStream(A.device.index) + + if A.dtype == torch.bfloat16: + fn = lib.cgemm_4bit_bf16 + elif A.dtype == torch.float16: + fn = lib.cgemm_4bit_fp16 + elif A.dtype == torch.float32: + fn = lib.cgemm_4bit_fp32 + else: + raise RuntimeError(f"unsupported dtype {A.dtype}") + + with _cuda_device_of(A): + fn( + A.data_ptr(), + B.data_ptr(), + absmax.data_ptr(), + absmax_8bit.data_ptr() if absmax_8bit is not None else None, + absmax_code.data_ptr() if absmax_code is not None else None, + absmax_offset.data_ptr() if absmax_offset is not None else None, + out.data_ptr(), + bias.data_ptr() if bias is not None else None, + M, + N, + K, + blocksize, + quant_type_int, + stream, + ) + + return out + +else: + + @register_kernel("bitsandbytes::gemm_4bit", "cuda") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + K = A.shape[-1] + M = A.numel() // K + N = shapeB[0] + + if M == 1: + if K % blocksize == 0: + if absmax_8bit is not None: + absmax = ( + torch.ops.bitsandbytes.dequantize_blockwise.default( + absmax_8bit, absmax, absmax_code, 256, torch.float32 + ) + + absmax_offset + ) + + code = _get_4bit_code(quant_type, A.device) + out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + if bias is not None: + out = out + bias + return out + + warn( + f"inner dimension 
({K}) is not aligned for fast kernel " + f"with blocksize={blocksize}, falling back to slower implementation.", + UserWarning, + ) + + return _gemm_4bit_default_impl( + A, + B, + shapeB, + absmax, + blocksize, + quant_type, + bias, + absmax_8bit=absmax_8bit, + absmax_code=absmax_code, + absmax_offset=absmax_offset, + ) + + """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adam": ( diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 6f5eecdf2..2f276edc6 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -347,6 +347,31 @@ def _( ) +def _gemm_4bit_default_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # When nested, per-block scale = absmax_code[absmax_8bit[i]] * absmax[i // 256] + absmax_offset + if absmax_8bit is not None: + absmax = ( + torch.ops.bitsandbytes.dequantize_blockwise.default(absmax_8bit, absmax, absmax_code, 256, torch.float32) + + absmax_offset + ) + B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype) + return torch.nn.functional.linear(A, B_dq, bias) + + +register_kernel("bitsandbytes::gemm_4bit", "default")(_gemm_4bit_default_impl) + + MOMENTUM = 0 RMSPROP = 1 ADAGRAD = 2 diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py index 83c2ae89e..07e62ad58 100644 --- a/bitsandbytes/backends/mps/ops.py +++ b/bitsandbytes/backends/mps/ops.py @@ -6,10 +6,12 @@ from collections.abc import Sequence from math import prod +from typing import Optional import torch from ..._ops import register_kernel +from ..default.ops import _gemm_4bit_default_impl # --------------------------------------------------------------------------- # Quant-type mapping: BnB uses strings, our Metal kernel uses ints. @@ -143,3 +145,48 @@ def _( ) -> None: result = _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize) out.copy_(result) + + +@register_kernel("bitsandbytes::gemm_4bit", "mps") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, +) -> torch.Tensor: + K = A.shape[-1] + M = A.numel() // K + N = shapeB[0] + + if absmax_8bit is not None: + absmax = ( + torch.ops.bitsandbytes.dequantize_blockwise.default(absmax_8bit, absmax, absmax_code, 256, torch.float32) + + absmax_offset + ) + + if M == 1: + if B.dtype != torch.uint8: + B = B.view(torch.uint8) + + k = _get_kernel() + result = k.gemv_4bit(A, B, absmax.view(N, -1), N, blocksize, _QUANT_MAP[quant_type]) + + if bias is not None: + result = result + bias + return result + + return _gemm_4bit_default_impl( + A, + B, + shapeB, + absmax, + blocksize, + quant_type, + bias, + ) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index bef07169c..a63e59a99 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -62,6 +62,18 @@ ) CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} +# Cache 4-bit dequantization code tensors per (quant_type, device). 
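+# get_4bit_type() rebuilds the 16-entry code tensor (and copies it to the target
+# device) on every call; the cache below avoids that cost on the hot gemv/gemm paths.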
+_code_4bit_cache: dict[tuple[str, torch.device], torch.Tensor] = {} + + +def _get_4bit_code(quant_type: str, device: torch.device) -> torch.Tensor: + key = (quant_type, device) + if key not in _code_4bit_cache: + from bitsandbytes.functional import get_4bit_type + + _code_4bit_cache[key] = get_4bit_type(quant_type, device=device) + return _code_4bit_cache[key] + def get_gaudi_sw_version(): """ diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index dfd0fb2d9..e95f9a0d2 100644 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,6 +1,7 @@ from collections.abc import Sequence import ctypes as ct import logging +from typing import Optional from packaging import version import torch @@ -9,7 +10,8 @@ from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib -from ..utils import triton_available +from ..default.ops import _gemm_4bit_default_impl +from ..utils import _get_4bit_code, triton_available logger = logging.getLogger(__name__) @@ -150,6 +152,52 @@ def _gemv_4bit_impl( ) +@register_kernel("bitsandbytes::gemm_4bit", "xpu") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + bias: Optional[torch.Tensor] = None, + absmax_8bit: Optional[torch.Tensor] = None, + absmax_code: Optional[torch.Tensor] = None, + absmax_offset: Optional[torch.Tensor] = None, +) -> torch.Tensor: + K = A.shape[-1] + M = A.numel() // K + + if M == 1: + if absmax_8bit is not None: + absmax = ( + torch.ops.bitsandbytes.dequantize_blockwise.default( + absmax_8bit, absmax, absmax_code, 256, torch.float32 + ) + + absmax_offset + ) + + code = _get_4bit_code(quant_type, A.device) + out = torch.ops.bitsandbytes.gemv_4bit.default(A, B, shapeB, absmax, code, blocksize) + + if bias is not None: + out = out + bias + return out + + return _gemm_4bit_default_impl( + A, + B, + shapeB, + absmax, + blocksize, + quant_type, + bias, + absmax_8bit=absmax_8bit, + absmax_code=absmax_code, + absmax_offset=absmax_offset, + ) + + # SYCL should be faster for xpu, so at first checking if it is available. if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): logger.info("Register sycl bitsandbytes kernels for XPU") @@ -229,6 +277,7 @@ def _( ) torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + elif triton_available: logger.info("Register triton bitsandbytes kernels for XPU") from ..triton import ops as triton_ops diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7796a8e84..70df070d7 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -114,6 +114,29 @@ def __init__(self, lib: ct.CDLL): lib.get_context.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p + # argtypes for the 4-bit GEMM entry points. 
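+        # getattr(..., None) below tolerates native libraries built without these
+        # symbols (e.g. builds where the 4-bit GEMM sources were not compiled in).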
+ _gemm4bit_argtypes = [ + ct.c_void_p, # A + ct.c_void_p, # B + ct.c_void_p, # absmax + ct.c_void_p, # absmax_8bit + ct.c_void_p, # absmax_code + ct.c_void_p, # absmax_offset + ct.c_void_p, # out + ct.c_void_p, # bias + ct.c_int32, # M + ct.c_int32, # N + ct.c_int32, # K + ct.c_int32, # blocksize + ct.c_int32, # quant_type + ct.c_void_p, # stream + ] + for _fn_name in ("cgemm_4bit_bf16", "cgemm_4bit_fp16", "cgemm_4bit_fp32"): + _fn = getattr(lib, _fn_name, None) + if _fn is not None: + _fn.argtypes = _gemm4bit_argtypes + _fn.restype = None + class XpuBNBNativeLibrary(BNBNativeLibrary): """XPU native library with SYCL USM paged memory support.""" diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0165a1288..d4ee98652 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1042,7 +1042,10 @@ def dequantize_4bit( quant_state.dtype, ) - if A.shape[0] == 1: # is transposed, transpose back + # BC shim: callers that pass the packed weight in transposed [1, (N*K+1)//2] form + # receive the output transposed back to [K, N]. bnb's own paths no longer trigger + # this since B is normalized to [(N*K+1)//2, 1] at the matmul_4bit entry point. + if A.shape[0] == 1: return out.t() return out diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index bfd41d5dd..ebc0b0943 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -611,18 +611,14 @@ def forward(self, x: torch.Tensor): quant_state = self.weight.quant_state if ( - not getattr(quant_state, "packing_format_for_cpu", False) - and x.device.type == "cpu" + x.device.type == "cpu" and self.support_avx512bf16_for_cpu and not self.training and x.requires_grad == False + and not getattr(quant_state, "packing_format_for_cpu", False) ): self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state) - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -631,10 +627,14 @@ def forward(self, x: torch.Tensor): if self.compute_dtype is not None: x = x.to(self.compute_dtype) - bias = None if self.bias is None else self.bias.to(self.compute_dtype) - weight = self.weight if getattr(quant_state, "packing_format_for_cpu", False) else self.weight.t() + bias = self.bias + if bias is not None: + if bias.dtype != x.dtype: + # TODO: do we need to cast bias like this? + bias.data = bias.data.to(x.dtype) + bias = bias.to(self.compute_dtype) - return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype) + return bnb.matmul_4bit(x, self.weight, bias=bias, quant_state=quant_state).to(inp_dtype) class LinearFP4(Linear4bit): diff --git a/csrc/gemm_4bit.cu b/csrc/gemm_4bit.cu new file mode 100644 index 000000000..557acabb4 --- /dev/null +++ b/csrc/gemm_4bit.cu @@ -0,0 +1,157 @@ +// C API for the custom 4-bit GEMM. +// Dispatches to SIMT or MMA kernels based on GPU architecture and shape. +// Computes out[M, N] = A[M, K] @ B[N, K]^T. All pointers are device memory. + +#include +#include +#include +#include +#include + +#include "gemm_4bit_simt.cuh" +#include "gemm_4bit_sm75.cuh" +#include "gemm_4bit_sm80.cuh" + +// 16-entry cache indexed by device ID. num_sms==0 means not yet populated. +// Static storage is zero-initialized, so all entries start unpopulated (num_sms==0). 
+GpuProps get_gpu_props() { + static GpuProps cache[16]; + int dev = 0; + cudaGetDevice(&dev); + if (dev < 16 && cache[dev].num_sms == 0) { + cudaDeviceGetAttribute(&cache[dev].num_sms, cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&cache[dev].cc_major, cudaDevAttrComputeCapabilityMajor, dev); + cudaDeviceGetAttribute(&cache[dev].cc_minor, cudaDevAttrComputeCapabilityMinor, dev); + } + return cache[dev]; +} + +/// @brief Fused 4-bit dequantize + GEMM. Computes out[M,N] = A[M,K] @ B[N,K]^T + bias. +/// +/// Dispatches to SIMT (sm60+) or MMA (sm75 fp16, sm80+ bf16/fp16) based on GPU arch and shape. +/// fp32 always uses SIMT. Supports single-level and double-quantized (nested) absmax. +/// +/// @tparam T Input/output dtype (`__nv_bfloat16`, `half`, or `float`) +template +static void gemm_4bit( + // clang-format off + const T* A, // inputs [M, K] + const uint8_t* B, // packed 4-bit weights [N, K/2] + const float* absmax, // fp32 absmax [N*K/blocksize] or [ceil(N*K/(blocksize*256))] when nested + const uint8_t* absmax_8bit, // [N*K/blocksize] uint8 compressed absmax; nullptr = non-nested + const float* absmax_code, // [256] codebook for 8bit absmax + const float* absmax_offset, // scalar; nullptr = non-nested + T* out, // [M, N] + const T* bias, // [N] optional, nullptr = no bias + int M, int N, int K, // problem shape + int blocksize, // elements per quantization block + int quant_type, // 1 = FP4, 2 = NF4 + cudaStream_t stream // CUDA stream + // clang-format on +) { + constexpr bool is_fp32 = std::is_same_v; + + // fp32 and M<=3 are always SIMT regardless of GPU -- skip the props lookup. + if (is_fp32 || M <= 3) { + launch_gemm_4bit_simt( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream + ); + return; + } + +#if defined(BNB_HAS_GEMM4BIT_SM75) || defined(BNB_HAS_GEMM4BIT_SM80) + const GpuProps gpu = get_gpu_props(); + const int num_sms = gpu.num_sms; + const int cc_maj = gpu.cc_major; + const int cc_min = gpu.cc_minor; + + const bool hbm_arch = (cc_maj == 8 && cc_min == 0) || cc_maj == 9 || cc_maj == 10; + const bool gddr_arch = !hbm_arch && cc_maj >= 8; + const int mma_blocks = ((M + 31) / 32) * ((N + 63) / 64); + + // sm86/sm89/sm120 with >= 48 SMs: high GDDR bandwidth means SIMT wins at M=4. + const bool highbw_gddr = (cc_maj == 8 && (cc_min == 6 || cc_min == 9) && num_sms >= 48) || + (cc_maj == 12 && cc_min == 0 && num_sms >= 48); + + // Below 2/3-wave at M<=8: SIMT keeps more warps in flight. + const bool undersubscribed = + (M <= 8 && mma_blocks * 3 <= num_sms * 2) || (hbm_arch && M == 4 && mma_blocks <= num_sms); + + // sm89 (>=60 SMs) and sm120 (>=48 SMs): at M<=6 with wide N, SIMT saturates + // bandwidth more efficiently than blocked MMA. + const bool wide_n_simt = + M <= 6 && mma_blocks >= num_sms && + ((cc_maj == 8 && cc_min == 9 && num_sms >= 60) || (cc_maj == 12 && cc_min == 0 && num_sms >= 48)); + + // GDDR tall-K (K>N): K-loop too long relative to output tile at small M. 
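+    // e.g. (hypothetical shape) M=16, N=2048, K=8192 on a 142-SM L40S:
+    // mma_blocks = 1 * 32 and 32 * 3 < 142, so the SIMT kernel is chosen.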
+ const bool tall_k_simt = gddr_arch && K > N && M <= 17 && mma_blocks * 3 < num_sms; + + const bool use_simt = (M == 4 && highbw_gddr) || undersubscribed || wide_n_simt || tall_k_simt || + (M <= 16 && mma_blocks * 4 <= num_sms) || (M <= 32 && mma_blocks * 8 <= num_sms) || + (K % 64 != 0); // MMA requirement + + if (!use_simt) { +#if defined(BNB_HAS_GEMM4BIT_SM80) + if (cc_maj >= 8) { + if constexpr (!is_fp32) { + launch_gemm_4bit_sm80_m16n8k16( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, + gpu, stream + ); + return; + } + } +#endif +#if defined(BNB_HAS_GEMM4BIT_SM75) + if (cc_maj == 7 && cc_min >= 5) { + // bf16 has no sm75 tensor core support; falls through to SIMT. + if constexpr (std::is_same_v) { + launch_gemm_4bit_sm75_m16n8k8( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, + gpu, stream + ); + return; + } + } +#endif + } +#endif // BNB_HAS_GEMM4BIT_SM75 || BNB_HAS_GEMM4BIT_SM80 + + launch_gemm_4bit_simt( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream + ); +} + +extern "C" { + +void cgemm_4bit_bf16( + const __nv_bfloat16* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, __nv_bfloat16* out, const __nv_bfloat16* bias, int M, int N, int K, int blocksize, + int quant_type, cudaStream_t stream +) { + gemm_4bit<__nv_bfloat16>( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream + ); +} + +void cgemm_4bit_fp16( + const half* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, half* out, const half* bias, int M, int N, int K, int blocksize, int quant_type, + cudaStream_t stream +) { + gemm_4bit( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream + ); +} + +void cgemm_4bit_fp32( + const float* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, float* out, const float* bias, int M, int N, int K, int blocksize, int quant_type, + cudaStream_t stream +) { + gemm_4bit( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream + ); +} + +} // extern "C" diff --git a/csrc/gemm_4bit_common.cuh b/csrc/gemm_4bit_common.cuh new file mode 100644 index 000000000..3febc8741 --- /dev/null +++ b/csrc/gemm_4bit_common.cuh @@ -0,0 +1,205 @@ +#pragma once + +// Shared types and utilities for 4bit GEMM kernels. + +// GPU properties queried once per device and cached in gemm_4bit.cu. +// Passed through dispatch into MMA launchers to avoid repeated cudaGetDevice calls. 
+struct GpuProps { + int num_sms, cc_major, cc_minor; +}; + +#include +#include +#include + +// NF4 dequantization LUT +static __device__ __constant__ float NF4_LUT_F32[16] = { + -1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f, +}; + +// FP4 dequantization LUT +static __device__ __constant__ float FP4_LUT_F32[16] = { + 0.0f, // 0b0000 + 0.005208333333f, // 0b0001 + 0.66666667f, // 0b0010 + 1.0f, // 0b0011 + 0.33333333f, // 0b0100 + 0.5f, // 0b0101 + 0.16666667f, // 0b0110 + 0.25f, // 0b0111 + -0.0f, // 0b1000 + -0.005208333333f, // 0b1001 + -0.66666667f, // 0b1010 + -1.0f, // 0b1011 + -0.33333333f, // 0b1100 + -0.5f, // 0b1101 + -0.16666667f, // 0b1110 + -0.25f, // 0b1111 +}; + +// MMA accumulator fragment for m16n8k* +struct FragC { + float x[4]; +}; + +/// @brief Compile-time layout of warps for mma.sync m16n8k* instructions +/// +/// Warps are split `WARPS_M x WARPS_N` where `WARPS_M * WARPS_N == NUM_WARPS_VAL`. +/// Default split: MT<=32 -> 2Mx(N/2); MT>32 -> 4Mx(N/4). +/// +/// @tparam MT Number of rows (e.g. 32, 64, 128) +/// @tparam NT Number of columns (e.g. 32, 64, 128, 256) +/// @tparam NUM_WARPS_VAL Total number of warps (e.g. 4, 8, 16); must be divisible by WARPS_M +/// @tparam MMA_M_VAL MMA M dimension (e.g. 16) +/// @tparam MMA_N_VAL MMA N dimension (e.g. 8) +template struct MmaWarpLayout { + static constexpr int WARPS_M = (MT <= 32) ? 2 : 4; + static constexpr int WARPS_N = NUM_WARPS_VAL / WARPS_M; + static constexpr int WARP_M = MT / WARPS_M; + static constexpr int WARP_MMA_M = WARP_M / MMA_M_VAL; + static constexpr int WARP_N = NT / WARPS_N; + static constexpr int WARP_MMA_N = WARP_N / MMA_N_VAL; +}; + +// Indicates whether `T` is a 16-bit float type supported in our kernels (currently `__nv_bfloat16` or `half`). +template constexpr bool is_16bit_float_v = std::is_same_v || std::is_same_v; + +// Convert two floats to a T vector-pair (half2 or __nv_bfloat162), with rounding. +template __device__ __forceinline__ auto make_vec2(float a, float b) { + static_assert(is_16bit_float_v, "make_vec2: T must be __nv_bfloat16 or half"); + if constexpr (std::is_same_v) + return __floats2bfloat162_rn(a, b); + else + return __floats2half2_rn(a, b); +} + +// Single float broadcast into both lanes of a T vector-pair. +template __device__ __forceinline__ auto broadcast_vec2(float x) { + static_assert(is_16bit_float_v, "broadcast_vec2: T must be __nv_bfloat16 or half"); + if constexpr (std::is_same_v) + return __float2bfloat162_rn(x); + else + return __float2half2_rn(x); +} + +// T vector-pair -> float2. +template __device__ __forceinline__ float2 vec2_to_float2(T2 v) { + if constexpr (std::is_same_v) + // __bfloat1622float2 is gated behind sm80+ in CUDA < 12.2. Two + // __bfloat162float calls emit identical PTX (cvt.f32.bf16 on sm90+, + // mov.b32 on earlier targets). + return {__bfloat162float(v.x), __bfloat162float(v.y)}; + else + return __half22float2(v); +} + +// Two uint32 values holding uint16 bit patterns -> T vector-pair. +// Used to reassemble warp-shuffled LUT entries. 
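+// (__shfl_sync exchanges 32-bit registers, so the 16-bit LUT entries travel
+// zero-extended inside uint32 lanes and are narrowed back to T here.)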
+template __device__ __forceinline__ auto vec2_from_u16bits(uint32_t hi, uint32_t lo) { + static_assert(is_16bit_float_v, "vec2_from_u16bits: T must be __nv_bfloat16 or half"); + using T2 = std::conditional_t, __nv_bfloat162, half2>; + const uint16_t hi16 = (uint16_t)hi, lo16 = (uint16_t)lo; + return T2{*reinterpret_cast(&hi16), *reinterpret_cast(&lo16)}; +} + +// Two f32 accumulators -> packed uint32 of two T values, via PTX cvt. +template __device__ __forceinline__ uint32_t cvt_f32x2_to_packed(float hi, float lo) { + static_assert(is_16bit_float_v, "cvt_f32x2_to_packed: T must be __nv_bfloat16 or half"); + uint32_t result; + if constexpr (std::is_same_v) { + asm("cvt.rn.bf16x2.f32 %0, %1, %2;" : "=r"(result) : "f"(hi), "f"(lo)); + } else { + // cvt.rn.f16x2.f32 requires sm80+; __floats2half2_rn is used instead + // so the compiler picks the right instruction per target. + // lo -> bits [15:0], hi -> bits [31:16]. + const half2 h2 = __floats2half2_rn(lo, hi); + result = *reinterpret_cast(&h2); + } + return result; +} + +/// @brief MMA epilogue: store f32 accumulators to C with optional bias, for m16n8k* layouts. +/// +/// D-fragment layout (mma.sync m16n8k*): +/// group = lane/4, tid = lane%4 +/// d[0] -> C[group, tid*2] d[1] -> C[group, tid*2+1] +/// d[2] -> C[group+MMA_M/2, tid*2] d[3] -> C[group+MMA_M/2, tid*2+1] +/// +/// @tparam T Output type (__nv_bfloat16 or half) +/// @tparam WARP_MMA_M Number of M-direction MMA tiles per warp +/// @tparam WARP_MMA_N Number of N-direction MMA tiles per warp +/// @tparam MMA_M_VAL MMA M dimension (default 16) +/// @tparam MMA_N_VAL MMA N dimension (default 8) +/// @param C Output pointer [M, N] +/// @param accum f32 accumulator fragments [WARP_MMA_M][WARP_MMA_N] +/// @param bm Block M offset (blockIdx.x * MT) +/// @param bn Block N offset (blockIdx.y * NT) +/// @param wm_off Warp M offset within block +/// @param wn_off Warp N offset within block +/// @param M Total output rows +/// @param N Total output columns +/// @param lane_id Lane index within warp (0-31) +/// @param bias Optional bias [N]; nullptr = no bias +template +__device__ __forceinline__ void mma_store_accum( + T* C, const FragC (&accum)[WARP_MMA_M][WARP_MMA_N], int bm, int bn, int wm_off, int wn_off, int M, int N, + int lane_id, + const T* bias = nullptr // [N], optional; accumulated in fp32 before downcast +) { + constexpr int ROW_STRIDE = MMA_M_VAL / 2; + const int group = lane_id / 4; + const int tid = lane_id % 4; + +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_MMA_N; wn++) { + const int base_m = bm + wm_off + wm * MMA_M_VAL; + const int base_n = bn + wn_off + wn * MMA_N_VAL; + const int m0 = base_m + group; + const int m1 = base_m + group + ROW_STRIDE; + const int n0 = base_n + tid * 2; + const int n1 = n0 + 1; + + // Bias is added in fp32 before the downcast. + const float b0 = bias ? static_cast(bias[n0]) : 0.0f; + const float b1 = (bias && n1 < N) ? 
static_cast(bias[n1]) : 0.0f; + + if (m0 < M) { + if (__builtin_expect(n1 < N, 1)) { + const uint32_t c01 = cvt_f32x2_to_packed(accum[wm][wn].x[1] + b1, accum[wm][wn].x[0] + b0); + // clang-format off + asm volatile("st.global.cs.b32 [%0], %1;" :: "l"(&C[m0 * N + n0]), "r"(c01)); + // clang-format on + } else if (n0 < N) { + C[m0 * N + n0] = static_cast(accum[wm][wn].x[0] + b0); + } + } + if (m1 < M) { + if (__builtin_expect(n1 < N, 1)) { + const uint32_t c23 = cvt_f32x2_to_packed(accum[wm][wn].x[3] + b1, accum[wm][wn].x[2] + b0); + // clang-format off + asm volatile("st.global.cs.b32 [%0], %1;" :: "l"(&C[m1 * N + n0]), "r"(c23)); + // clang-format on + } else if (n0 < N) { + C[m1 * N + n0] = static_cast(accum[wm][wn].x[2] + b0); + } + } + } + } +} diff --git a/csrc/gemm_4bit_simt.cu b/csrc/gemm_4bit_simt.cu new file mode 100644 index 000000000..fdcd12d53 --- /dev/null +++ b/csrc/gemm_4bit_simt.cu @@ -0,0 +1,298 @@ +// SIMT 4-bit GEMM kernel. Compiles for all architectures (sm60+). + +#include +#include +#include +#include + +#include "gemm_4bit_common.cuh" +#include "gemm_4bit_simt.cuh" + +// Warps per block; each warp owns one N-column. CTA size = WARPS_PER_BLOCK * 32. +static constexpr int WARPS_PER_BLOCK = 4; + +// Element-wise multiply of a half2 vector pair. +__device__ __forceinline__ half2 vec2_mul(half2 a, half2 b) { return __hmul2(a, b); } + +// Element-wise multiply of a __nv_bfloat162 vector pair. +__device__ __forceinline__ __nv_bfloat162 vec2_mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && __CUDA_VERSION__ < 12020 + // Falls back to fp32 round-trip on <= sm75 + CUDA < 12.2, which lacks a __hmul2 + // overload for bf16 on Turing; identical to what CUDA 12.2+ emits for that target. + return __floats2bfloat162_rn( + __bfloat162float(a.x) * __bfloat162float(b.x), __bfloat162float(a.y) * __bfloat162float(b.y) + ); +#else + return __hmul2(a, b); +#endif +} + +/// @brief Fused 4-bit dequantize + GEMM with optional nested absmax and bias. +/// Computes C[M,N] = A[M,K] @ B[N,K]^T + bias. +/// +/// One warp owns one N-column. All 32 lanes split K in parallel, then warp-reduce. +/// B is dequantized once per outer K step and reused across all M rows. +/// Supports single-level and double-quantized (nested) absmax. 
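+/// Accumulation is fp32 regardless of T; each output element is a warp reduction
+/// over 32 per-lane partial sums.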
+/// +/// Dtype paths: +/// bf16/fp16: LUT as uint16-in-uint32 for warp shuffle; T2 pair math; 1x uint4 A load per sub-iter +/// fp32: LUT as float for warp shuffle; scalar multiply; 2x uint4 A loads per sub-iter +/// +/// Grid: (ceil(N/WARPS_PER_BLOCK), ceil(M/M_BLOCK)) +/// +/// @tparam T Input/output dtype (`__nv_bfloat16`, `half`, or `float`) +/// @tparam M_BLOCK M rows per block +template +__global__ void __launch_bounds__(WARPS_PER_BLOCK * 32) gemm_4bit_simt( + // clang-format off + const T* __restrict__ A, // inputs [M, K] + const uint8_t* __restrict__ B, // packed 4-bit weights [N, K/2] + const float* __restrict__ absmax, // fp32 absmax [N, K/blocksize] or + // [ceil(N*K/(blocksize*256))] when nested + const uint8_t* __restrict__ absmax_8bit, // [N, K/blocksize] uint8 compressed absmax; + // nullptr = non-nested + const float* __restrict__ absmax_code, // [256] codebook for 8bit absmax + const float* __restrict__ absmax_offset, // scalar; nullptr = non-nested + T* __restrict__ C, // [M, N] + const T* __restrict__ bias, // [N] optional, nullptr = no bias + int M, int N, int K, // problem shape + int blocksize, // elements per quantization block + int quant_type // 1 = FP4, 2 = NF4 + // clang-format on +) { + // T2 is the vector-pair type; for float it resolves to float2 but is unused. + using T2 = std::conditional_t< + std::is_same_v, __nv_bfloat162, std::conditional_t, half2, float2>>; + + const float absmax_offset_f = absmax_8bit ? __ldg(absmax_offset) : 0.0f; + + const int blocksize_log2 = __ffs(blocksize) - 1; + + constexpr int NUM_VAL = 32; // 4-bit elements per lane per outer K step + constexpr int K_STRIDE = 32 * NUM_VAL; // = 1024 K elements per outer step + + const int lane_id = threadIdx.x & 31; + const int warp_id = threadIdx.x >> 5; + const int warp_n = blockIdx.x * WARPS_PER_BLOCK + warp_id; + const int base_m = blockIdx.y * M_BLOCK; + + if (warp_n >= N) + return; + + // Each lane loads its LUT entry (lane_id < 16) for warp-shuffle dequant. + // bf16/fp16: centroid as uint16-in-uint32 for __shfl_sync over uint32. + // fp32: centroid as float for __shfl_sync over float. + [[maybe_unused]] uint32_t my_lut_u32 = 0u; + [[maybe_unused]] float my_lut_f32_shfl = 0.0f; + + if (lane_id < 16) { + const float* lut = (quant_type == 1) ? FP4_LUT_F32 : NF4_LUT_F32; + if constexpr (std::is_same_v) + my_lut_f32_shfl = lut[lane_id]; + else { + const T centroid = static_cast(lut[lane_id]); + my_lut_u32 = (uint32_t)(*reinterpret_cast(¢roid)); + } + } + + const int blk_per_row = K >> blocksize_log2; + + // Per-M accumulators. M_BLOCK is compile-time so the loop fully unrolls. + float acc[M_BLOCK]; +#pragma unroll + for (int m = 0; m < M_BLOCK; m++) + acc[m] = 0.f; + + const int m_valid = min(M_BLOCK, max(0, M - base_m)); + + // All 32 lanes run the same number of K iterations for `__shfl_sync` convergence. + // Inactive lanes load b_packed4={0}, which decodes to 0 and contributes nothing. + const int num_groups = (K + K_STRIDE - 1) / K_STRIDE; + + for (int g = 0; g < num_groups; g++) { + const int inner_k = g * K_STRIDE + lane_id * NUM_VAL; + const bool lane_active = (inner_k < K); + + // Scale: one absmax per lane per outer K step. + float scale_f = 0.0f; + if (lane_active) { + const int blk_idx = warp_n * blk_per_row + (inner_k >> blocksize_log2); + if (absmax_8bit) { + // absmax_8bit[blk_idx] is a uint8 index into absmax_code. + // absmax[blk_idx >> 8] is the fp32 state2 scale (one per 256 blocks). + // absmax_offset is subtracted at quantize time and re-added here. 
+ scale_f = __ldg(&absmax_code[absmax_8bit[blk_idx]]) * __ldg(&absmax[blk_idx >> 8]) + absmax_offset_f; + } else { + scale_f = __ldg(&absmax[blk_idx]); + } + } + + // Load 16 packed 4-bit bytes (.cs: streaming, bypasses L1). + // Inactive lanes keep b_packed4 = {0}. + uint4 b_packed4 = {0u, 0u, 0u, 0u}; + if (lane_active) { + const uint32_t* bptr = reinterpret_cast(B + warp_n * (K / 2) + inner_k / 2); + // clang-format off + asm volatile( + "ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(b_packed4.x), "=r"(b_packed4.y), + "=r"(b_packed4.z), "=r"(b_packed4.w) + : "l"(bptr) + ); + // clang-format on + } + const uint8_t* b_bytes = reinterpret_cast(&b_packed4); + + // Decode B and accumulate. + // bf16/fp16: uint16-in-uint32 LUT, hmul2 vector math, 1x uint4 A load per sub-iter. + // fp32: float LUT, scalar multiply, 2x uint4 A loads per sub-iter. + [[maybe_unused]] T2 scale_x2{}; + if constexpr (!std::is_same_v) + scale_x2 = broadcast_vec2(scale_f); + +#pragma unroll + for (int sub = 0; sub < 4; sub++) { + // Decode 4 B bytes (8 nibbles) into 8 dequantized values. + // hi nibble (>>4) = lower K index, lo nibble (&0xf) = higher K index. + [[maybe_unused]] T2 b_chunk[4]; + [[maybe_unused]] float b_dq[8]; +#pragma unroll + for (int j = 0; j < 4; j++) { + const uint8_t byte = b_bytes[sub * 4 + j]; + if constexpr (std::is_same_v) { + b_dq[j * 2] = __shfl_sync(0xffffffff, my_lut_f32_shfl, byte >> 4) * scale_f; + b_dq[j * 2 + 1] = __shfl_sync(0xffffffff, my_lut_f32_shfl, byte & 0x0f) * scale_f; + } else { + const uint32_t hi = __shfl_sync(0xffffffff, my_lut_u32, byte >> 4); + const uint32_t lo = __shfl_sync(0xffffffff, my_lut_u32, byte & 0x0f); + b_chunk[j] = vec2_mul(vec2_from_u16bits(hi, lo), scale_x2); + } + } + + if (lane_active) { +#pragma unroll + for (int m = 0; m < M_BLOCK; m++) { + if (m >= m_valid) + break; + const int m_global = base_m + m; + const int a_k = inner_k + sub * 8; + + if constexpr (std::is_same_v) { + // 8 floats = 2x uint4 (cached; A rows reused across N-tiles) + const uint4 a4a = *reinterpret_cast(&A[m_global * K + a_k]); + const uint4 a4b = *reinterpret_cast(&A[m_global * K + a_k + 4]); + const float* fa = reinterpret_cast(&a4a); + const float* fb = reinterpret_cast(&a4b); +#pragma unroll + for (int k = 0; k < 4; k++) { + acc[m] += fa[k] * b_dq[k]; + acc[m] += fb[k] * b_dq[k + 4]; + } + } else { + // 8 T elements as uint4 (cached; A rows reused across N-tiles). + // sm90+: asm fence gives better occupancy. +#if __CUDA_ARCH__ >= 900 + uint4 a_packed4; + // clang-format off + asm volatile( + "ld.global.ca.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a_packed4.x), "=r"(a_packed4.y), + "=r"(a_packed4.z), "=r"(a_packed4.w) + : "l"(&A[m_global * K + a_k]) + ); + // clang-format on +#else + const uint4 a_packed4 = *reinterpret_cast(&A[m_global * K + a_k]); +#endif + const T2* a_pairs = reinterpret_cast(&a_packed4); +#pragma unroll + for (int k = 0; k < 4; k++) { + const float2 p = vec2_to_float2(vec2_mul(a_pairs[k], b_chunk[k])); + acc[m] += p.x + p.y; + } + } + } + } + } + } + +#pragma unroll + // Warp reduce: sum acc[m] across all 32 lanes. + for (int m = 0; m < M_BLOCK; m++) { + if (m >= m_valid) + break; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + acc[m] += __shfl_down_sync(0xffffffff, acc[m], offset); + } + + if (lane_id == 0) { + const float bias_f = bias ? 
static_cast(bias[warp_n]) : 0.0f; +#pragma unroll + for (int m = 0; m < M_BLOCK; m++) { + if (m >= m_valid) + break; + C[(base_m + m) * N + warp_n] = static_cast(acc[m] + bias_f); + } + } +} + +/// @brief Host launcher for the SIMT 4-bit GEMM kernel. Selects M_BLOCK at compile +/// time based on M (exact for M=1..8, clamped to 8 above). +/// @tparam T Input/output dtype (`__nv_bfloat16`, `half`, or `float`) +template +void launch_gemm_4bit_simt( + // clang-format off + const T* A, // inputs [M, K] + const uint8_t* B, // packed 4-bit weights [N, K/2] + const float* absmax, // fp32 absmax [N*K/blocksize] or [ceil(N*K/(blocksize*256))] when nested + const uint8_t* absmax_8bit, // [N*K/blocksize] uint8 compressed absmax; nullptr = non-nested + const float* absmax_code, // [256] codebook for 8bit absmax + const float* absmax_offset, // scalar; nullptr = non-nested + T* C, // output [M, N] + const T* bias, // [N] optional, nullptr = no bias + int M, int N, int K, // problem shape + int blocksize, // elements per quantization block + int quant_type, // 1 = FP4, 2 = NF4 + cudaStream_t stream // CUDA stream + // clang-format on +) { + const int n_blocks = (N + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK; + + // M=1..8: M_BLOCK == M, so the inner m-loop fully unrolls at compile time. + // M>8: M_BLOCK=8, ceil(M/8) grid rows. + auto launch = [&](auto mb_tag) { + constexpr int MB = decltype(mb_tag)::value; + const int grid_y = (M + MB - 1) / MB; + gemm_4bit_simt<<>>( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type + ); + }; + + // clang-format off + switch (min(M, 8)) { + case 1: launch(std::integral_constant{}); break; + case 2: launch(std::integral_constant{}); break; + case 3: launch(std::integral_constant{}); break; + case 4: launch(std::integral_constant{}); break; + case 5: launch(std::integral_constant{}); break; + case 6: launch(std::integral_constant{}); break; + case 7: launch(std::integral_constant{}); break; + default: launch(std::integral_constant{}); break; + } + // clang-format on +} + +// Explicit instantiations for supported dtypes. +template void launch_gemm_4bit_simt<__nv_bfloat16>( + const __nv_bfloat16*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, __nv_bfloat16*, + const __nv_bfloat16*, int, int, int, int, int, cudaStream_t +); +template void launch_gemm_4bit_simt( + const half*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, half*, const half*, int, int, + int, int, int, cudaStream_t +); +template void launch_gemm_4bit_simt( + const float*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, float*, const float*, int, + int, int, int, int, cudaStream_t +); diff --git a/csrc/gemm_4bit_simt.cuh b/csrc/gemm_4bit_simt.cuh new file mode 100644 index 000000000..e29f08dab --- /dev/null +++ b/csrc/gemm_4bit_simt.cuh @@ -0,0 +1,17 @@ +#pragma once +// Launcher declarations for the SIMT 4-bit GEMM kernel. +// T must be __nv_bfloat16, half, or float. +// Compiles for all architectures (sm60+, no tensor core requirement). +// absmax_8bit == nullptr selects the non-nested (standard) path. 
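+// M_BLOCK (1-8 output rows per block) is chosen inside the launcher from M; see
+// gemm_4bit_simt.cu for the selection and csrc/gemm_4bit.cu for arch dispatch.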
+ +#include +#include +#include +#include + +template +void launch_gemm_4bit_simt( + const T* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, T* C, const T* bias, int M, int N, int K, int blocksize, int quant_type, + cudaStream_t stream +); diff --git a/csrc/gemm_4bit_sm75.cu b/csrc/gemm_4bit_sm75.cu new file mode 100644 index 000000000..be5ec99a3 --- /dev/null +++ b/csrc/gemm_4bit_sm75.cu @@ -0,0 +1,413 @@ +// sm75 MMA (mma.sync.aligned.m16n8k8) 4-bit GEMM kernel. fp16 only. + +#include +#include +#include + +#include "gemm_4bit_common.cuh" +#include "gemm_4bit_sm75.cuh" + +static constexpr int K_CHUNK = 64; +[[maybe_unused]] static constexpr int MMA_M = 16; +[[maybe_unused]] static constexpr int MMA_N = 8; +[[maybe_unused]] static constexpr int MMA_K = 8; + +static constexpr int NUM_WARPS = 8; +static constexpr int CTA_SIZE = NUM_WARPS * 32; + +// K_CHUNK + 8 stride: 16-byte row alignment, limits bank conflicts to 4-way for K_CHUNK=64. +static constexpr int SMEM_A_STRIDE = K_CHUNK + 8; // 72 half elements +static constexpr int SMEM_B_STRIDE = K_CHUNK + 8; + +/// @brief In-place warp-level MMA: accum += A * B (fp16, m16n8k8). +/// Called once per K chunk to accumulate the full matrix product. +/// +/// Executes mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32. +/// Fragments are distributed across warp lanes per PTX spec. +/// +/// @param accum In-place f32 accumulator (4 regs per lane) +/// @param a A fragment: 16x8 fp16 operand (2 regs per lane) +/// @param b B fragment: 8x8 fp16 operand (1 reg per lane) +__device__ __forceinline__ void mma_m16n8k8(FragC& accum, const uint32_t a[2], const uint32_t b[1]) { + // clang-format off + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n" + : "+f"(accum.x[0]), "+f"(accum.x[1]), "+f"(accum.x[2]), "+f"(accum.x[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]) + ); + // clang-format on +} + +// A: m16 x k8 from row-major smem. ldmatrix.x2 loads two stacked m8xk8 halves. +// Thread t provides the address for row m_off + (t % 16) at column k_off. +// Lanes 16-31 mirror 0-15 (both sets address the same 16 rows for .x2). +__device__ __forceinline__ void load_A_frag(uint32_t frag[2], const half* smem_a, int m_off, int k_off, int lane) { + const int m_row = m_off + (lane % 16); + const uint32_t addr = static_cast(__cvta_generic_to_shared(&smem_a[m_row * SMEM_A_STRIDE + k_off])); + // clang-format off + asm volatile( + "ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(frag[0]), "=r"(frag[1]) + : "r"(addr) + ); + // clang-format on +} + +// B: k8 x n8 from row-major smem (= col-major [k][n]). ldmatrix.x1 loads one m8xk8. +// Thread t addresses row n_off + (t % 8) at column k_off. +__device__ __forceinline__ void load_B_frag(uint32_t frag[1], const half* smem_b, int n_off, int k_off, int lane) { + const int n_row = n_off + (lane % 8); + const uint32_t addr = static_cast(__cvta_generic_to_shared(&smem_b[n_row * SMEM_B_STRIDE + k_off])); + // clang-format off + asm volatile( + "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(frag[0]) + : "r"(addr) + ); + // clang-format on +} + +template static constexpr int smem_bytes_for() { + return 2 * (MT * SMEM_A_STRIDE + NT * SMEM_B_STRIDE) * static_cast(sizeof(T)); +} + +/// @brief Fused 4-bit dequantize + MMA GEMM for sm75 (fp16 only). +/// Computes C[M,N] = A[M,K] @ B[N,K]^T + bias. 
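+/// Accumulates in fp32: the m16n8k8 MMA uses f32 accumulator fragments even though
+/// both operands are fp16.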
+/// +/// Layout: +/// A: [M, K] row-major, half (activations) +/// B: [N, K/2] row-major, packed uint8 (2 nibbles per byte, weights) +/// C: [M, N] row-major, half (output) +/// +/// MMA: `mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32` +/// A operand: row-major [M, K] +/// B operand: col-major [K, N] (B [N, K] row-major reinterpreted) +/// +/// Double-buffered smem pipeline. Supports optional nested absmax and bias. +/// +/// Smem per double-buffer (K_CHUNK=64, half=2 bytes): +/// MT=32, NT=64: 27 KB MT=32, NT=128: 45 KB +/// MT=64, NT=64: 36 KB MT=64, NT=128: 54 KB +/// +/// @tparam T Input/output dtype (must be `half`; static_assert enforces) +/// @tparam MT M tile size (32 or 64) +/// @tparam NT N tile size (64 or 128) +template +__global__ void __launch_bounds__(CTA_SIZE) gemm_4bit_sm75_m16n8k8( + // clang-format off + const T* __restrict__ A, // inputs [M, K] + const uint8_t* __restrict__ B, // packed 4-bit weights [N, K/2] + const float* __restrict__ absmax, // fp32 absmax [N, K/blocksize] or + // [ceil(N*K/(blocksize*256))] when nested + const uint8_t* __restrict__ absmax_8bit, // [N, K/blocksize] uint8 compressed absmax; + // nullptr = non-nested + const float* __restrict__ absmax_code, // [256] codebook for 8bit absmax + const float* __restrict__ absmax_offset, // scalar; nullptr = non-nested + T* __restrict__ C, // [M, N] + const T* __restrict__ bias, // [N] optional, nullptr = no bias + int M, int N, int K, // problem shape + int blocksize, // elements per quantization block + int quant_type // 1 = FP4, 2 = NF4 + // clang-format on +) { +#if __CUDA_ARCH__ == 750 + static_assert(std::is_same_v, "sm75 MMA requires fp16 (half)"); + static_assert(MT == 32 || MT == 64, "MT must be 32 or 64"); + static_assert(NT == 64 || NT == 128, "NT must be 64 or 128"); + + using WL = MmaWarpLayout; + constexpr int WARPS_M = WL::WARPS_M; + constexpr int WARPS_N = WL::WARPS_N; + constexpr int WARP_M = WL::WARP_M; + constexpr int WARP_MMA_M = WL::WARP_MMA_M; + constexpr int WARP_N = WL::WARP_N; + constexpr int WARP_MMA_N = WL::WARP_MMA_N; + + static_assert(MT >= WARPS_M * MMA_M, "MT too small for warp layout"); + static_assert(NT >= WARPS_N * MMA_N, "NT too small for warp layout"); + + // Packed bytes of B per thread per K-chunk (NT=64: 8 bytes; NT=128: 16 bytes). + constexpr int B_BYTES = NT * (K_CHUNK / 2) / CTA_SIZE; + static_assert(B_BYTES == 8 || B_BYTES == 16, "unexpected B bytes per thread"); + + extern __shared__ char smem_raw[]; + constexpr int buf_offset = (MT * SMEM_A_STRIDE + NT * SMEM_B_STRIDE) * sizeof(T); + auto smem_a_buf = [&](int buf) -> half* { return reinterpret_cast(smem_raw + buf * buf_offset); }; + auto smem_b_buf = [&](int buf) -> half* { + return reinterpret_cast(smem_raw + buf * buf_offset + MT * SMEM_A_STRIDE * sizeof(T)); + }; + + const float absmax_offset_f = absmax_8bit ? __ldg(absmax_offset) : 0.0f; + + const int bm = blockIdx.x * MT; + const int bn = blockIdx.y * NT; + + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int warp_m = warp_id / WARPS_N; + const int warp_n = warp_id % WARPS_N; + + const float* lut = (quant_type == 1) ? FP4_LUT_F32 : NF4_LUT_F32; + const float my_lut_f32 = (lane_id < 16) ? 
lut[lane_id] : 0.0f; + + const int k_iters = K / K_CHUNK; + const int blocksize_log2 = __ffs(blocksize) - 1; + const int blocks_per_row = K >> blocksize_log2; + + FragC accum[WARP_MMA_M][WARP_MMA_N]; +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) +#pragma unroll + for (int wn = 0; wn < WARP_MMA_N; wn++) { + accum[wm][wn].x[0] = 0.f; + accum[wm][wn].x[1] = 0.f; + accum[wm][wn].x[2] = 0.f; + accum[wm][wn].x[3] = 0.f; + } + + // Tile loading: A (direct copy from global) + B (load packed 4-bit + dequant to half2) + auto load_tile = [&](int k_iter, int buf) { + const int k_base = k_iter * K_CHUNK; + half* __restrict__ sa = smem_a_buf(buf); + half* __restrict__ sb = smem_b_buf(buf); + + // Load A: (MT/32) x uint4 per thread, .ca + constexpr int vecs_per_thread = MT / 32; +#pragma unroll + for (int v = 0; v < vecs_per_thread; v++) { + const int vec_idx = threadIdx.x * vecs_per_thread + v; + const int row = vec_idx / 8; + const int col = (vec_idx % 8) * 8; + const int g_row = bm + row; + uint4 val = {0u, 0u, 0u, 0u}; + if (g_row < M) { + // clang-format off + asm volatile( + "ld.global.ca.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(&A[g_row * K + k_base + col]) + ); + // clang-format on + } + *reinterpret_cast(&sa[row * SMEM_A_STRIDE + col]) = val; + } + + // Load + dequant B + { + const int byte_start = threadIdx.x * B_BYTES; + const int n_local = byte_start / (K_CHUNK / 2); + const int k_byte0 = byte_start % (K_CHUNK / 2); + const int n_global = bn + n_local; + const int k_elem0 = k_byte0 * 2; + + float scale_f = 0.0f; + if (n_global < N) { + const int blk_idx = n_global * blocks_per_row + ((k_base + k_elem0) >> blocksize_log2); + if (absmax_8bit) { + // absmax_8bit[blk_idx] indexes absmax_code (256-entry codebook). + // absmax[blk_idx >> 8] is the fp32 state2 scale (one per 256 blocks). + scale_f = __ldg(&absmax_code[__ldg(&absmax_8bit[blk_idx])]) * __ldg(&absmax[blk_idx >> 8]) + + absmax_offset_f; + } else { + scale_f = __ldg(&absmax[blk_idx]); + } + } + + // Centroid * scale in fp32 avoids double rounding to half. + auto dequant_byte = [&](uint8_t byte, int smem_off) { + const float hi = __shfl_sync(0xffffffff, my_lut_f32, byte >> 4); + const float lo = __shfl_sync(0xffffffff, my_lut_f32, byte & 0x0f); + const auto dq = make_vec2(hi * scale_f, lo * scale_f); + *reinterpret_cast(&sb[n_local * SMEM_B_STRIDE + smem_off]) = + *reinterpret_cast(&dq); + }; + + if constexpr (B_BYTES == 16) { + uint4 packed4 = {0u, 0u, 0u, 0u}; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(packed4.x), "=r"(packed4.y), + "=r"(packed4.z), "=r"(packed4.w) + : "l"(&B[n_global * (K / 2) + k_base / 2 + k_byte0]) + ); + // clang-format on + } + const uint8_t* bytes = reinterpret_cast(&packed4); +#pragma unroll + for (int j = 0; j < 16; j++) + dequant_byte(bytes[j], k_elem0 + j * 2); + } else { + uint2 packed2 = {0u, 0u}; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v2.u32 {%0,%1}, [%2];\n" + : "=r"(packed2.x), "=r"(packed2.y) + : "l"(&B[n_global * (K / 2) + k_base / 2 + k_byte0]) + ); + // clang-format on + } + const uint8_t* bytes = reinterpret_cast(&packed2); +#pragma unroll + for (int j = 0; j < 8; j++) + dequant_byte(bytes[j], k_elem0 + j * 2); + } + } + }; + + // Compute MMA on one buffer. K_CHUNK/MMA_K = 8 iterations. 
+ auto compute = [&](int buf) { + const half* sa = smem_a_buf(buf); + const half* sb = smem_b_buf(buf); + const int wm_off = warp_m * WARP_M; + const int wn_off = warp_n * WARP_N; + + for (int kk = 0; kk < K_CHUNK / MMA_K; kk++) { + const int k_off = kk * MMA_K; + + uint32_t a_frag[WARP_MMA_M][2]; +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) + load_A_frag(a_frag[wm], sa, wm_off + wm * MMA_M, k_off, lane_id); + +#pragma unroll + for (int wn = 0; wn < WARP_MMA_N; wn++) { + uint32_t b_frag[1]; + load_B_frag(b_frag, sb, wn_off + wn * MMA_N, k_off, lane_id); +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) + mma_m16n8k8(accum[wm][wn], a_frag[wm], b_frag); + } + } + }; + + // Main loop: double buffered + load_tile(0, 0); + __syncthreads(); + + for (int k_iter = 0; k_iter < k_iters; k_iter++) { + const int cur_buf = k_iter % 2; + const int next_buf = 1 - cur_buf; + if (k_iter + 1 < k_iters) + load_tile(k_iter + 1, next_buf); + compute(cur_buf); + __syncthreads(); + } + + mma_store_accum( + C, accum, bm, bn, warp_m * WARP_M, warp_n * WARP_N, M, N, lane_id, bias + ); +#endif // __CUDA_ARCH__ == 750 +} + +template +static void launch_tile( + // clang-format off + const T* A, + const uint8_t* B, + const float* absmax, + const uint8_t* absmax_8bit, + const float* absmax_code, + const float* absmax_offset, + T* C, + const T* bias, + int M, int N, int K, + int blocksize, + int quant_type, + cudaStream_t stream + // clang-format on +) { + constexpr int smem = smem_bytes_for(); + static bool cfg = false; + if (!cfg) { + cudaFuncSetAttribute(gemm_4bit_sm75_m16n8k8, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cfg = true; + } + dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT); + gemm_4bit_sm75_m16n8k8<<>>( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type + ); +} + +/// @brief Auto-dispatch launcher for the sm75 MMA kernel. Selects MT/NT tile +/// based on GPU SM count and shape. +/// @tparam T Input/output dtype (must be `half`) +template +void launch_gemm_4bit_sm75_m16n8k8( + // clang-format off + const T* A, + const uint8_t* B, + const float* absmax, + const uint8_t* absmax_8bit, + const float* absmax_code, + const float* absmax_offset, + T* C, + const T* bias, + int M, int N, int K, + int blocksize, + int quant_type, + GpuProps gpu, + cudaStream_t stream + // clang-format on +) { + static_assert(std::is_same_v, "sm75 MMA requires fp16 (half)"); + + const int num_sms = gpu.num_sms; + const int m_blocks_32 = (M + 31) / 32; + const int m_blocks_64 = (M + 63) / 64; + const int n_blocks_128 = (N + 127) / 128; + const int n_blocks_64 = (N + 63) / 64; + const int blocks_32x64 = m_blocks_32 * n_blocks_64; + + // Suppress MT=32 when M just crossed a 64-row boundary and the extra + // MT=32 block would be nearly empty -- but only once 64x128 reaches 0.25 wave. + // Below 0.25 wave, MT=32 still wins by keeping more blocks in flight. + const bool mt32_boundary_waste = (m_blocks_32 > m_blocks_64) && (m_blocks_64 * n_blocks_128 >= num_sms / 4); + + const bool use_mt32 = (M < 48 && !mt32_boundary_waste) || + (M <= 128 && blocks_32x64 >= num_sms * 2 && m_blocks_32 * 32 < m_blocks_64 * 64 && + m_blocks_64 * n_blocks_128 < num_sms * 3) || + (m_blocks_64 * n_blocks_128 < num_sms / 4); // 64x128 < 0.25 wave + + int m_blocks = use_mt32 ? m_blocks_32 : m_blocks_64; + int mt = use_mt32 ? 
32 : 64; + int nt; + + if (mt == 32) { + // NT=128 only at very high occupancy (>=5 waves); NT=64 otherwise gives + // 2x more blocks and wins at normal occupancy on sm75. + const bool use_nt64 = (m_blocks * n_blocks_128 < num_sms * 5) && (n_blocks_64 > n_blocks_128); + nt = use_nt64 ? 64 : 128; + } else { + // Fall back to MT=32 when 64x128 is severely undersubscribed (< 0.25 wave). + if (m_blocks * n_blocks_128 < num_sms / 4 && n_blocks_64 > n_blocks_128) { + m_blocks = m_blocks_32; + mt = 32; + const bool use_nt64 = (m_blocks * n_blocks_128 < num_sms * 5) && (n_blocks_64 > n_blocks_128); + nt = use_nt64 ? 64 : 128; + } else { + // 64x128 otherwise, except 64x64 when NT=128 is below 0.5 wave. + nt = (m_blocks * n_blocks_128 < num_sms / 2) ? 64 : 128; + } + } + + // clang-format off +#define LAUNCH_SM75(MT, NT) \ + launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream) + + if (mt == 32 && nt == 64) LAUNCH_SM75(32, 64); + else if (mt == 32 && nt == 128) LAUNCH_SM75(32, 128); + else if (mt == 64 && nt == 64) LAUNCH_SM75(64, 64); + else if (mt == 64 && nt == 128) LAUNCH_SM75(64, 128); +#undef LAUNCH_SM75 + // clang-format on +} + +// Explicit instantiation +template void launch_gemm_4bit_sm75_m16n8k8( + const half*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, half*, const half*, int, int, + int, int, int, GpuProps, cudaStream_t +); diff --git a/csrc/gemm_4bit_sm75.cuh b/csrc/gemm_4bit_sm75.cuh new file mode 100644 index 000000000..c693bb261 --- /dev/null +++ b/csrc/gemm_4bit_sm75.cuh @@ -0,0 +1,16 @@ +#pragma once +// Launcher declarations for the sm75 MMA 4-bit GEMM kernel. +// fp16 only; Turing (sm75) has no bf16 tensor core support. + +#include +#include +#include + +#include "gemm_4bit_common.cuh" + +template +void launch_gemm_4bit_sm75_m16n8k8( + const T* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, T* C, const T* bias, int M, int N, int K, int blocksize, int quant_type, GpuProps gpu, + cudaStream_t stream +); diff --git a/csrc/gemm_4bit_sm80.cu b/csrc/gemm_4bit_sm80.cu new file mode 100644 index 000000000..e4c74460b --- /dev/null +++ b/csrc/gemm_4bit_sm80.cu @@ -0,0 +1,696 @@ +// sm80+ MMA (mma.sync.aligned.m16n8k16) 4-bit GEMM kernel (bf16 and fp16). + +#include +#include +#include +#include +#include + +#include "gemm_4bit_common.cuh" +#include "gemm_4bit_sm80.cuh" + +[[maybe_unused]] static constexpr int MMA_M = 16; +[[maybe_unused]] static constexpr int MMA_N = 8; +[[maybe_unused]] static constexpr int MMA_K = 16; + +static constexpr int NUM_WARPS = 8; +static constexpr int CTA_SIZE = NUM_WARPS * 32; + +/// @brief In-place warp-level MMA: accum += A * B (bf16 or fp16, m16n8k16). +/// Called once per K chunk to accumulate the full matrix product. +/// +/// Executes mma.sync.aligned.m16n8k16.row.col.f32.{bf16,f16}.{bf16,f16}.f32. +/// Fragments are distributed across warp lanes per PTX spec. 
+/// +/// @tparam T Input dtype; selects bf16 or fp16 PTX instruction +/// @param accum In-place f32 accumulator (4 regs per lane) +/// @param a A fragment: 16x16 operand (4 regs per lane) +/// @param b B fragment: 16x8 operand (2 regs per lane) +template +__device__ __forceinline__ void mma_m16n8k16(FragC& accum, const uint32_t a[4], const uint32_t b[2]) { + if constexpr (std::is_same_v) { + // clang-format off + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n" + : "+f"(accum.x[0]), "+f"(accum.x[1]), "+f"(accum.x[2]), "+f"(accum.x[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]) + ); + // clang-format on + } else { + // clang-format off + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n" + : "+f"(accum.x[0]), "+f"(accum.x[1]), "+f"(accum.x[2]), "+f"(accum.x[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]) + ); + // clang-format on + } +} + +// Smem stride = K_CHUNK + 8 elements (same for A and B). +// +8 padding: 16-byte row alignment, limits bank conflicts to 4-way for both +// K_CHUNK=64 (stride=72, gcd(36,32)=4) and K_CHUNK=128 (stride=136, gcd(68,32)=4). + +// A: m16 x k16 from row-major [m][k] smem. ldmatrix.x4. +template +__device__ __forceinline__ void + load_A_frag(uint32_t frag[4], const T* smem_a, int m_off, int k_off, int lane, int stride) { + const int mat_idx = lane / 8; + const int row_in_mat = lane % 8; + const int m_row = m_off + row_in_mat + (mat_idx & 1) * 8; + const int k_col = k_off + (mat_idx >> 1) * 8; + const uint32_t addr = static_cast(__cvta_generic_to_shared(&smem_a[m_row * stride + k_col])); + // clang-format off + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3]) + : "r"(addr) + ); + // clang-format on +} + +// B: k16 x n8 from row-major [n][k] smem (= col-major [k][n]). ldmatrix.x2. +template +__device__ __forceinline__ void + load_B_frag(uint32_t frag[2], const T* smem_b, int n_off, int k_off, int lane, int stride) { + int n_row = lane % 8; + int k_col = k_off + (lane / 8) * 8; + if (lane >= 16) { + n_row = (lane - 16) % 8; + k_col = k_off + ((lane - 16) / 8) * 8; + } + const uint32_t addr = static_cast(__cvta_generic_to_shared(&smem_b[(n_off + n_row) * stride + k_col])); + // clang-format off + asm volatile( + "ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(frag[0]), "=r"(frag[1]) + : "r"(addr) + ); + // clang-format on +} + +// Smem size per double-buffer in bytes (sizeof(T) == 2 for both bf16 and fp16). +template static constexpr int smem_bytes_for() { + constexpr int STRIDE = K_CHUNK + 8; + return 2 * (MT + NT) * STRIDE * static_cast(sizeof(T)); +} + +/// @brief Fused 4-bit dequantize + MMA GEMM for sm80+ (bf16 and fp16). +/// Computes C[M,N] = A[M,K] @ B[N,K]^T + bias. +/// +/// Layout: +/// A: [M, K] row-major, T (activations) +/// B: [N, K/2] row-major, packed uint8 (2 nibbles per byte, weights) +/// C: [M, N] row-major, T (output) +/// +/// MMA: `mma.sync.aligned.m16n8k16.row.col.f32.{bf16,f16}.{bf16,f16}.f32` +/// A operand: row-major [M, K] +/// B operand: col-major [K, N] (B [N, K] row-major reinterpreted) +/// +/// Double-buffered smem pipeline. Supports optional nested absmax and bias. 
+/// +/// Smem per double-buffer = 2*(MT+NT)*(K_CHUNK+8)*sizeof(T) bytes: +/// KC=64: 32x 64 27KB 32x128 45KB 64x 64 36KB +/// 64x128 54KB 128x 64 54KB 128x128 72KB +/// KC=128: 32x 64 51KB 32x128 85KB 32x256 153KB +/// 64x 32 51KB 64x 64 68KB +/// +/// @tparam T Input/output dtype (`__nv_bfloat16` or `half`) +/// @tparam MT M tile size (32, 64, or 128) +/// @tparam NT N tile size (32, 64, 128, or 256) +/// @tparam K_CHUNK K elements per outer iteration (64 or 128) +template +__global__ void __launch_bounds__(CTA_SIZE) gemm_4bit_sm80_m16n8k16( + // clang-format off + const T* __restrict__ A, // inputs [M, K] + const uint8_t* __restrict__ B, // packed 4-bit weights [N, K/2] + const float* __restrict__ absmax, // fp32 absmax [N, K/blocksize] or + // [ceil(N*K/(blocksize*256))] when nested + const uint8_t* __restrict__ absmax_8bit, // [N, K/blocksize] uint8 compressed absmax; + // nullptr = non-nested + const float* __restrict__ absmax_code, // [256] codebook for 8bit absmax + const float* __restrict__ absmax_offset, // scalar; nullptr = non-nested + T* __restrict__ C, // [M, N] + const T* __restrict__ bias, // [N] optional, nullptr = no bias + int M, int N, int K, // problem shape + int blocksize, // elements per quantization block + int quant_type // 1 = FP4, 2 = NF4 + // clang-format on +) { +#if __CUDA_ARCH__ >= 800 + static_assert(MT == 32 || MT == 64 || MT == 128, "MT must be 32, 64, or 128"); + static_assert(NT == 32 || NT == 64 || NT == 128 || NT == 256, "NT must be 32, 64, 128, or 256"); + static_assert(K_CHUNK == 64 || K_CHUNK == 128, "K_CHUNK must be 64 or 128"); + + // Trap on tile+arch combinations that the dispatcher can never reach. + // if constexpr prunes the kernel body entirely at compile time; __trap() + // provides a visible error if this assumption is ever violated at runtime. + // clang-format off +#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 + // HBM (sm90/sm100/sm103): MT=32 always returns K_CHUNK=128 with NT=64 or NT=256; + // MT=64 never uses NT=32 (GDDR-only fallback). + // sm80 is intentionally excluded: these four GDDR tiles are unused on sm80 itself + // but must remain in the sm80 cubin so it can run correctly on sm86/sm89. + if constexpr ((MT==32 && NT== 64 && K_CHUNK== 64) || + (MT==32 && NT==128 && K_CHUNK== 64) || + (MT==32 && NT==128 && K_CHUNK==128) || + (MT==64 && NT== 32 && K_CHUNK==128)) { __trap(); return; } +#endif +#if __CUDA_ARCH__ == 900 + // sm90: only 64x64-64 is dispatched at MT=64 (NT forced to 64, K_CHUNK=128 never selected). + if constexpr (MT==64 && !(NT==64 && K_CHUNK==64)) { __trap(); return; } +#endif +#if __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 + // sm100/sm103 (B200/B300): only 64x64-128 is dispatched at MT=64. + if constexpr (MT==64 && !(NT==64 && K_CHUNK==128)) { __trap(); return; } +#endif +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210 + // GDDR (sm86/sm89/sm120/sm121): NT=256 only exists in the HBM MT=32 path. + if constexpr (MT==32 && NT==256 && K_CHUNK==128) { __trap(); return; } +#endif + // clang-format on + + const float absmax_offset_f = absmax_8bit ? __ldg(absmax_offset) : 0.0f; + + // K_CHUNK + 8 padding: 16-byte row alignment, limits bank conflicts to 4-way. 
+ constexpr int SMEM_A_STRIDE = K_CHUNK + 8; + constexpr int SMEM_B_STRIDE = K_CHUNK + 8; + + using WL = MmaWarpLayout; + constexpr int WARPS_M = WL::WARPS_M; + constexpr int WARPS_N = WL::WARPS_N; + constexpr int WARP_M = WL::WARP_M; + constexpr int WARP_MMA_M = WL::WARP_MMA_M; + constexpr int WARP_N = WL::WARP_N; + constexpr int WARP_MMA_N = WL::WARP_MMA_N; + + static_assert(MT >= WARPS_M * MMA_M, "MT too small for warp layout"); + static_assert(NT >= WARPS_N * MMA_N, "NT too small for warp layout"); + + // Packed bytes of B per thread per K-chunk: + // NT= 64, KC= 64 -> 8 bytes (uint2) + // NT=128, KC= 64 -> 16 bytes (uint4) + // NT= 32, KC=128 -> 8 bytes (uint2) + // NT= 64, KC=128 -> 16 bytes (uint4) + // NT=128, KC=128 -> 32 bytes (2x uint4) + // NT=256, KC=128 -> 64 bytes (4x uint4) + constexpr int B_BYTES = NT * (K_CHUNK / 2) / CTA_SIZE; + static_assert(B_BYTES == 8 || B_BYTES == 16 || B_BYTES == 32 || B_BYTES == 64, "unexpected B bytes per thread"); + + extern __shared__ char smem_raw[]; + constexpr int buf_offset = (MT * SMEM_A_STRIDE + NT * SMEM_B_STRIDE) * sizeof(T); + auto smem_a_buf = [&](int buf) -> T* { return reinterpret_cast(smem_raw + buf * buf_offset); }; + auto smem_b_buf = [&](int buf) -> T* { + return reinterpret_cast(smem_raw + buf * buf_offset + MT * SMEM_A_STRIDE * sizeof(T)); + }; + + const int bm = blockIdx.x * MT; + const int bn = blockIdx.y * NT; + + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int warp_m = warp_id / WARPS_N; + const int warp_n = warp_id % WARPS_N; + + // LUT in fp32: centroid * scale in fp32 avoids double rounding to output dtype. + const float* lut = (quant_type == 1) ? FP4_LUT_F32 : NF4_LUT_F32; + const float my_lut_f32 = (lane_id < 16) ? lut[lane_id] : 0.0f; + + const int k_iters = K / K_CHUNK; + const int blocksize_log2 = __ffs(blocksize) - 1; + const int blocks_per_row = K >> blocksize_log2; + + FragC accum[WARP_MMA_M][WARP_MMA_N]; +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) +#pragma unroll + for (int wn = 0; wn < WARP_MMA_N; wn++) { + accum[wm][wn].x[0] = 0.f; + accum[wm][wn].x[1] = 0.f; + accum[wm][wn].x[2] = 0.f; + accum[wm][wn].x[3] = 0.f; + } + + // Tile loading: A (direct copy from global) + B (load packed 4-bit + dequant) + auto load_tile = [&](int k_iter, int buf) { + const int k_base = k_iter * K_CHUNK; + T* __restrict__ sa = smem_a_buf(buf); + T* __restrict__ sb = smem_b_buf(buf); + + // Load A: each thread loads vecs_per_thread uint4 (8 T elements each). + constexpr int vecs_per_row = K_CHUNK / 8; + constexpr int vecs_per_thread = MT * vecs_per_row / CTA_SIZE; +#pragma unroll + for (int v = 0; v < vecs_per_thread; v++) { + const int vec_idx = threadIdx.x * vecs_per_thread + v; + const int row = vec_idx / vecs_per_row; + const int col = (vec_idx % vecs_per_row) * 8; + const int g_row = bm + row; + uint4 val = {0u, 0u, 0u, 0u}; + if (g_row < M) { + // clang-format off + asm volatile( + "ld.global.ca.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(&A[g_row * K + k_base + col]) + ); + // clang-format on + } + *reinterpret_cast(&sa[row * SMEM_A_STRIDE + col]) = val; + } + + // Pre-fetch nested_idx_b before scale_f computation; the A-tile loads above + // provide enough in-flight latency to hide the absmax_8bit read. 
+ const int byte_start_b = threadIdx.x * B_BYTES; + const int n_local_b = byte_start_b / (K_CHUNK / 2); + const int k_byte0_b = byte_start_b % (K_CHUNK / 2); + const int n_global_b = bn + n_local_b; + const int k_elem0_b = k_byte0_b * 2; + const int blk_idx_b = n_global_b * blocks_per_row + ((k_base + k_elem0_b) >> blocksize_log2); + uint8_t nested_idx_b = 0; + if (absmax_8bit && n_global_b < N) + nested_idx_b = __ldg(&absmax_8bit[blk_idx_b]); + + // Load + dequant B + { + const int n_local = n_local_b; + const int k_byte0 = k_byte0_b; + const int n_global = n_global_b; + const int k_elem0 = k_elem0_b; + // (k_base + k_elem0) selects the correct absmax block for any blocksize: + // for blocksize >= K_CHUNK the index collapses to k_iter; for smaller + // blocksizes each thread's k_elem0 selects the right sub-chunk block. + const int blk_idx = blk_idx_b; + + float scale_f = 0.0f; + if (n_global < N) { + if (absmax_8bit) { + // nested_idx_b was pre-fetched above; __ldg keeps the 1KB codebook in read-only L1. + scale_f = __ldg(&absmax_code[nested_idx_b]) * __ldg(&absmax[blk_idx >> 8]) + absmax_offset_f; + } else { + scale_f = __ldg(&absmax[blk_idx]); + } + } + + // fp32 centroid * scale avoids double rounding; hi nibble (>>4) = lower K index. + auto dequant_byte = [&](uint8_t byte, int smem_off) { + const float hi = __shfl_sync(0xffffffff, my_lut_f32, byte >> 4); + const float lo = __shfl_sync(0xffffffff, my_lut_f32, byte & 0x0f); + const auto dq = make_vec2(hi * scale_f, lo * scale_f); + *reinterpret_cast(&sb[n_local * SMEM_B_STRIDE + smem_off]) = + *reinterpret_cast(&dq); + }; + + if constexpr (B_BYTES == 64) { + // NT=256, KC=128: each thread covers one N-column x K_CHUNK K-elements. + // scale_f for the first block is fetched above. The loop refreshes it at + // each absmax block boundary using (j*2 & (blocksize-1)) == 0, which is + // a single bitmask AND per unrolled step since blocksize is always a power of 2. + uint4 p0 = {0, 0, 0, 0}, p1 = {0, 0, 0, 0}, p2 = {0, 0, 0, 0}, p3 = {0, 0, 0, 0}; + const uint8_t* bptr = &B[n_global * (K / 2) + k_base / 2 + k_byte0]; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%16];\n" + "ld.global.cs.v4.u32 {%4,%5,%6,%7}, [%16+16];\n" + "ld.global.cs.v4.u32 {%8,%9,%10,%11}, [%16+32];\n" + "ld.global.cs.v4.u32 {%12,%13,%14,%15},[%16+48];\n" + : "=r"(p0.x), "=r"(p0.y), "=r"(p0.z), "=r"(p0.w), + "=r"(p1.x), "=r"(p1.y), "=r"(p1.z), "=r"(p1.w), + "=r"(p2.x), "=r"(p2.y), "=r"(p2.z), "=r"(p2.w), + "=r"(p3.x), "=r"(p3.y), "=r"(p3.z), "=r"(p3.w) + : "l"(bptr) + ); + // clang-format on + } + uint8_t bytes[64]; + memcpy(bytes, &p0, 16); + memcpy(bytes + 16, &p1, 16); + memcpy(bytes + 32, &p2, 16); + memcpy(bytes + 48, &p3, 16); +#pragma unroll + for (int j = 0; j < B_BYTES; j++) { + if (j > 0 && ((j * 2) & (blocksize - 1)) == 0 && n_global < N) { + const int blk_idx = n_global * blocks_per_row + ((k_base + k_elem0 + j * 2) >> blocksize_log2); + if (absmax_8bit) { + scale_f = __ldg(&absmax_code[__ldg(&absmax_8bit[blk_idx])]) * __ldg(&absmax[blk_idx >> 8]) + + absmax_offset_f; + } else { + scale_f = __ldg(&absmax[blk_idx]); + } + } + dequant_byte(bytes[j], k_elem0 + j * 2); + } + } else if constexpr (B_BYTES == 32) { + // NT=128, KC=128: same block boundary refresh as B_BYTES=64. + // Needed when blocksize < K_CHUNK (e.g. blocksize=32 with KC=128). 
+ uint4 p0 = {0, 0, 0, 0}, p1 = {0, 0, 0, 0}; + const uint8_t* bptr = &B[n_global * (K / 2) + k_base / 2 + k_byte0]; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%8];\n" + "ld.global.cs.v4.u32 {%4,%5,%6,%7}, [%8+16];\n" + : "=r"(p0.x), "=r"(p0.y), "=r"(p0.z), "=r"(p0.w), + "=r"(p1.x), "=r"(p1.y), "=r"(p1.z), "=r"(p1.w) + : "l"(bptr) + ); + // clang-format on + } + uint8_t bytes[32]; + memcpy(bytes, &p0, 16); + memcpy(bytes + 16, &p1, 16); +#pragma unroll + for (int j = 0; j < B_BYTES; j++) { + if (j > 0 && ((j * 2) & (blocksize - 1)) == 0 && n_global < N) { + const int blk_idx = n_global * blocks_per_row + ((k_base + k_elem0 + j * 2) >> blocksize_log2); + if (absmax_8bit) { + scale_f = __ldg(&absmax_code[__ldg(&absmax_8bit[blk_idx])]) * __ldg(&absmax[blk_idx >> 8]) + + absmax_offset_f; + } else { + scale_f = __ldg(&absmax[blk_idx]); + } + } + dequant_byte(bytes[j], k_elem0 + j * 2); + } + } else if constexpr (B_BYTES == 16) { + uint4 packed4 = {0u, 0u, 0u, 0u}; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];\n" + : "=r"(packed4.x), "=r"(packed4.y), + "=r"(packed4.z), "=r"(packed4.w) + : "l"(&B[n_global * (K / 2) + k_base / 2 + k_byte0]) + ); + // clang-format on + } + const uint8_t* bytes = reinterpret_cast(&packed4); +#pragma unroll + for (int j = 0; j < 16; j++) + dequant_byte(bytes[j], k_elem0 + j * 2); + } else if constexpr (B_BYTES == 8) { + uint2 packed2 = {0u, 0u}; + if (n_global < N) { + // clang-format off + asm volatile( + "ld.global.cs.v2.u32 {%0,%1}, [%2];\n" + : "=r"(packed2.x), "=r"(packed2.y) + : "l"(&B[n_global * (K / 2) + k_base / 2 + k_byte0]) + ); + // clang-format on + } + const uint8_t* bytes = reinterpret_cast(&packed2); +#pragma unroll + for (int j = 0; j < 8; j++) + dequant_byte(bytes[j], k_elem0 + j * 2); + } + } + }; + + // Compute MMA on one buffer + auto compute = [&](int buf) { + const T* sa = smem_a_buf(buf); + const T* sb = smem_b_buf(buf); + const int wm_off = warp_m * WARP_M; + const int wn_off = warp_n * WARP_N; + + for (int kk = 0; kk < K_CHUNK / MMA_K; kk++) { + const int k_off = kk * MMA_K; + uint32_t a_frag[WARP_MMA_M][4]; +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) + load_A_frag(a_frag[wm], sa, wm_off + wm * MMA_M, k_off, lane_id, SMEM_A_STRIDE); +#pragma unroll + for (int wn = 0; wn < WARP_MMA_N; wn++) { + uint32_t b_frag[2]; + load_B_frag(b_frag, sb, wn_off + wn * MMA_N, k_off, lane_id, SMEM_B_STRIDE); +#pragma unroll + for (int wm = 0; wm < WARP_MMA_M; wm++) + mma_m16n8k16(accum[wm][wn], a_frag[wm], b_frag); + } + } + }; + + // Main loop: double buffered + load_tile(0, 0); + __syncthreads(); + + for (int k_iter = 0; k_iter < k_iters; k_iter++) { + const int cur_buf = k_iter % 2; + const int next_buf = 1 - cur_buf; + if (k_iter + 1 < k_iters) + load_tile(k_iter + 1, next_buf); + compute(cur_buf); + __syncthreads(); + } + + mma_store_accum( + C, accum, bm, bn, warp_m * WARP_M, warp_n * WARP_N, M, N, lane_id, bias + ); +#endif // __CUDA_ARCH__ >= 800 +} + +template +static void launch_tile( + // clang-format off + const T* A, + const uint8_t* B, + const float* absmax, + const uint8_t* absmax_8bit, + const float* absmax_code, + const float* absmax_offset, + T* C, + const T* bias, + int M, int N, int K, + int blocksize, + int quant_type, + cudaStream_t stream + // clang-format on +) { + constexpr int smem = smem_bytes_for(); + static bool cfg = false; + if (!cfg) { + cudaFuncSetAttribute(gemm_4bit_sm80_m16n8k16, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cfg = true; + } + dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT); + gemm_4bit_sm80_m16n8k16<<>>( + A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type + ); +} + +/// @brief Auto-dispatch launcher for the sm80+ MMA kernel. Selects MT/NT/KC tile +/// based on GPU arch, SM count, and shape. +/// @tparam T Input/output dtype (`__nv_bfloat16` or `half`) +template +void launch_gemm_4bit_sm80_m16n8k16( + // clang-format off + const T* A, + const uint8_t* B, + const float* absmax, + const uint8_t* absmax_8bit, + const float* absmax_code, + const float* absmax_offset, + T* C, + const T* bias, + int M, int N, int K, + int blocksize, + int quant_type, + GpuProps gpu, + cudaStream_t stream + // clang-format on +) { + const int num_sms = gpu.num_sms; + const int cc_major = gpu.cc_major; + const int cc_minor = gpu.cc_minor; + + const bool hbm_arch = (cc_major == 8 && cc_minor == 0) || cc_major == 9 || cc_major == 10; + const bool sm86 = cc_major == 8 && cc_minor == 6; + // A10-class sm86 cards run ~600 GB/s; RTX 3090 (82 SMs) and A40 (84 SMs) are above + // the 72-SM threshold and have higher bandwidth (behavior not validated here). + const bool sm86_low_bw = sm86 && num_sms <= 72; + // RTX Pro 6000 (sm120, 188 SMs) is the only validated high-SM sm120 card. + const bool high_sm_sm120 = cc_major == 12 && num_sms >= 150; + + const int m_blocks_32 = (M + 31) / 32; + const int m_blocks_64 = (M + 63) / 64; + const int m_blocks_128 = (M + 127) / 128; + const int n_blocks_64 = (N + 63) / 64; + const int n_blocks_128 = (N + 127) / 128; + + const int blocks_32x64 = m_blocks_32 * n_blocks_64; + const int blocks_64x64 = m_blocks_64 * n_blocks_64; + const int blocks_128x128 = m_blocks_128 * n_blocks_128; + + // MT=128 wastes M-rows when M % 128 is in [1, 64]: the last tile is <=50% full. + const bool mt128_row_waste = (M % 128 != 0) && (M % 128 <= 64); + // MT=128 is adequate at >=3/4 wave. At M=96 on A10 (N=8192: 0.89 waves), + // MT=128 compute efficiency beats MT=32 despite modest under-subscription. + const bool mt128_adequate = blocks_128x128 * 4 >= num_sms * 3; + // Suppress MT=32 when M just crossed a 64-row boundary and the extra MT=32 block is + // nearly empty. Once blocks_64x64 >= num_sms (1 wave with 64x64 tiles), MT=64 fills cleanly. + const bool mt32_boundary_waste = (m_blocks_32 > m_blocks_64) && (blocks_64x64 >= num_sms); + + const bool use_mt32 = + (M < 48 && !mt32_boundary_waste) || (M <= 128 && !(mt32_boundary_waste && mt128_adequate) && + blocks_32x64 >= num_sms * 2 && m_blocks_32 * 32 < m_blocks_64 * 64); + + const bool use_mt64 = !use_mt32 && (mt128_row_waste || !mt128_adequate) && blocks_64x64 > blocks_128x128; + + struct Tile { + int mt, nt, kc; + }; + + auto select_tile = [&]() -> Tile { + int mt = use_mt32 ? 32 : (use_mt64 ? 64 : 128); + + if (mt == 64) { + // sm86 low-BW, tall-K (K>=N) at >=0.5 wave: MT=128 amortizes B-matrix bandwidth + // more efficiently than MT=64. + // Calibrated on A10: 128x128 beats 64x128 at M=97-128 for tall-K. + if (sm86_low_bw && K >= N && !mt128_row_waste && blocks_128x128 * 2 >= num_sms) + mt = 128; + // High-SM sm120 below 1 wave: MT=64 is severely under-subscribed. + // MT=32 doubles M-block count and recovers occupancy. + // Calibrated on RTX Pro 6000: N=8192,K=8192 M=63-64. + else if (high_sm_sm120 && blocks_64x64 < num_sms) + mt = 32; + // HBM narrow-N below 0.5 wave. + // Calibrated on A100: N=512,K=4096 M=48-384. 
+ else if (hbm_arch && blocks_64x64 < num_sms / 2) + mt = 32; + // Short-K small-weight (K*2=1 wave and wins when more M-blocks are in flight. + else if (K * 2 < N && (long long)N * K < 4LL * 1024 * 1024 && blocks_32x64 >= num_sms) + mt = 32; + // GDDR short-K at 1/4-3/4 wave: 32x64-128 outperforms MT=64 because the shorter + // K-loop makes KC=128 ILP more valuable than a wider output tile. + // num_sms >= 60 excludes L4 (58 SMs) where KC=64 competes. + else if (!hbm_arch && K * 2 < N && num_sms >= 60 && blocks_64x64 >= num_sms / 4 && + blocks_64x64 < num_sms * 3 / 4) + return {32, 64, 128}; + } + + if (mt == 32) { + if (hbm_arch) { + // HBM MT=32: always KC=128 (longer chunks hide HBM latency). + // NT=256 in the 3/4-to-1-wave window; NT=64's 4x block advantage + // takes over above 1 wave. + // Calibrated on A100/H100/H200: N=36864 (>1 wave), 32x64 beats 32x256. + const int n_blocks_256 = (N + 255) / 256; + const int blocks_32x256 = m_blocks_32 * n_blocks_256; + const int nt = (blocks_32x256 * 4 >= num_sms * 3 && blocks_32x256 <= num_sms) ? 256 : 64; + return {32, nt, 128}; + } + + // GDDR MT=32: KC and NT driven by occupancy and register pressure. + if (m_blocks_32 >= 2 && blocks_32x64 > num_sms) { + // M>=33, >1 wave: register pressure at this occupancy favors KC=64. + // NT=128 in the 1-2 wave window on sm86 or short-K (K*3<=N). + if (m_blocks_32 >= 3) { + const int blocks_32x128 = m_blocks_32 * n_blocks_128; + if (blocks_32x128 > num_sms && blocks_32x128 <= num_sms * 2 && (sm86 || K * 3 <= N)) + return {32, 128, 64}; + } + return {32, 64, 64}; + } + if (m_blocks_32 == 1) { + // M<33, just above 1 wave (1.0-1.2x): KC=64 wins in this narrow window. + if (blocks_32x64 > num_sms && blocks_32x64 < num_sms + num_sms / 5) + return {32, 64, 64}; + // M<33, >3 waves on sm86: NT=128 + KC=64. + // Calibrated on A10: N=14336-36864 M=5-32; not validated on sm89/sm120. + if (sm86 && blocks_32x64 > num_sms * 3) + return {32, 128, 64}; + } + // 64x32-128: NT=32 gives 2x more N-blocks than NT=64 when occupancy-limited. + // Excluded for high-SM sm120 (no calibrated wins on RTX Pro 6000). + if (!high_sm_sm120) { + const int blocks_64x32 = m_blocks_64 * ((N + 31) / 32); + if (blocks_64x32 <= num_sms) + return {64, 32, 128}; + } + // KC=128 fallback: NT=128 when >=2/3 wave, NT=64 otherwise. + const int nt = (m_blocks_32 * n_blocks_128 * 3 >= num_sms * 2) ? 128 : 64; + return {32, nt, 128}; + } + + if (mt == 64) { + // sm90/sm100: no calibrated wins for 64x128; default to NT=64. + // A100 and GDDR retain NT=128 when well-subscribed. + int nt; + if (cc_major == 9 || cc_major == 10) + nt = 64; + else if (m_blocks_64 * n_blocks_128 < num_sms) + nt = 64; + else + nt = 128; + // KC=128 for sm100 (Blackwell) or GDDR tall-K (K>=N). + // sm80: no calibrated wins for KC=128 at MT=64. + if (nt == 64 && (cc_major == 10 || (!hbm_arch && K >= N))) + return {64, 64, 128}; + return {64, nt, 64}; + } + + // MT=128: KC=128 not dispatched (needs M>768 benchmarks to calibrate HBM crossover). + // NT=128 threshold is halved on sm86 low-BW vs other arches. + // Calibrated on A10: N=8192,K=2048 M=127, NT=128 wins 1.49x vs NT=64. + const int nt128_min_wave = sm86_low_bw ? num_sms / 2 : num_sms; + const int nt = (blocks_128x128 >= nt128_min_wave) ? 128 : 64; + return {128, nt, 64}; + }; + + auto [mt, nt, kc] = select_tile(); + + // KC=128 requires K%128==0 (k_iters = K/K_CHUNK truncates the remainder). + // Remap to the best valid KC=64 tile for this arch without changing any __trap() guards. 
+ if (kc == 128 && K % 128 != 0) { + kc = 64; + if (hbm_arch) { + // MT=32 KC=64 tiles are trapped on all HBM arches; sm100 also traps MT=64 KC=64. + nt = 64; + mt = (cc_major == 10) ? 128 : 64; + } else if (nt == 32) { + nt = 64; // no 64x32-64 tile + } + } + + // clang-format off +#define LAUNCH_SM80(MT, NT, KC) \ + launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream) + + if (kc == 64) { + if (mt == 32 && nt == 64) LAUNCH_SM80( 32, 64, 64); + else if (mt == 32 && nt == 128) LAUNCH_SM80( 32, 128, 64); + else if (mt == 64 && nt == 64) LAUNCH_SM80( 64, 64, 64); + else if (mt == 64 && nt == 128) LAUNCH_SM80( 64, 128, 64); + else if (mt == 128 && nt == 64) LAUNCH_SM80(128, 64, 64); + else if (mt == 128 && nt == 128) LAUNCH_SM80(128, 128, 64); + } else { + if (mt == 32 && nt == 64) LAUNCH_SM80( 32, 64, 128); + else if (mt == 32 && nt == 128) LAUNCH_SM80( 32, 128, 128); + else if (mt == 32 && nt == 256) LAUNCH_SM80( 32, 256, 128); + else if (mt == 64 && nt == 32) LAUNCH_SM80( 64, 32, 128); + else if (mt == 64 && nt == 64) LAUNCH_SM80( 64, 64, 128); + // else if (mt == 64 && nt == 128) LAUNCH_SM80( 64, 128, 128); // unreachable: KC=128 at MT=64 requires NT=64 + // else if (mt == 128 && nt == 64) LAUNCH_SM80(128, 64, 128); // unreachable: MT=128 always dispatches KC=64 + // else if (mt == 128 && nt == 128) LAUNCH_SM80(128, 128, 128); // unreachable: same + } +#undef LAUNCH_SM80 + // clang-format on +} + +// Explicit instantiations +template void launch_gemm_4bit_sm80_m16n8k16<__nv_bfloat16>( + const __nv_bfloat16*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, __nv_bfloat16*, + const __nv_bfloat16*, int, int, int, int, int, GpuProps, cudaStream_t +); +template void launch_gemm_4bit_sm80_m16n8k16( + const half*, const uint8_t*, const float*, const uint8_t*, const float*, const float*, half*, const half*, int, int, + int, int, int, GpuProps, cudaStream_t +); diff --git a/csrc/gemm_4bit_sm80.cuh b/csrc/gemm_4bit_sm80.cuh new file mode 100644 index 000000000..56c7c201b --- /dev/null +++ b/csrc/gemm_4bit_sm80.cuh @@ -0,0 +1,16 @@ +#pragma once +// Launcher declarations for the sm80+ MMA 4-bit GEMM kernel. 
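The tile heuristics in the dispatcher above are phrased in "waves", i.e. grid blocks produced by an MT x NT output tiling divided by the SM count. A rough sketch of that arithmetic (Python, illustrative only; the helper name is an assumption):

    def waves(M, N, mt, nt, num_sms):
        m_blocks = (M + mt - 1) // mt
        n_blocks = (N + nt - 1) // nt
        return (m_blocks * n_blocks) / num_sms

    # e.g. M=80, N=8192 on a 108-SM GPU: 128x128 tiles give 1*64/108 ~= 0.6 waves,
    # while 32x64 tiles give 3*128/108 ~= 3.6 waves.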
+ +#include +#include +#include +#include + +#include "gemm_4bit_common.cuh" + +template +void launch_gemm_4bit_sm80_m16n8k16( + const T* A, const uint8_t* B, const float* absmax, const uint8_t* absmax_8bit, const float* absmax_code, + const float* absmax_offset, T* C, const T* bias, int M, int N, int K, int blocksize, int quant_type, GpuProps gpu, + cudaStream_t stream +); diff --git a/pyproject.toml b/pyproject.toml index 745f74df4..cf9e00708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,9 @@ ignore = [ "bitsandbytes/**/triton/**/*.py" = [ "I001", # import order ] +"bitsandbytes/backends/utils.py" = [ + "I001", # import order +] [tool.ruff.lint.isort] combine-as-imports = true diff --git a/tests/test_autograd.py b/tests/test_autograd.py index deaf93785..d150f4735 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -126,39 +126,38 @@ def test_matmullt( if req_grad[0]: torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[2]: + if req_grad[2] and req_grad[0]: torch.testing.assert_close(gradBias1, gradBias2) @pytest.mark.parametrize("device", get_available_devices()) -@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4")) -@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) -@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("transpose_B", TRUE_FALSE, ids=id_formatter("transpose_B")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) def test_matmul_4bit( device, - dim1, dim2, dim3, dim4, - funcs, dtype, req_grad, - transpose, + transpose_B, has_bias, compress_statistics, quant_type, ): - dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) - dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - if has_bias == False: + # transpose_B=True: B is [N, K] -- standard convention + # transpose_B=False: B is [K, N] -- deprecated orientation, emits DeprecationWarning + dimA = (dim2, dim3) + dimB = (dim4, dim3) if transpose_B else (dim3, dim4) + + if not has_bias: req_grad = list(req_grad) req_grad[2] = False @@ -169,71 +168,69 @@ def test_matmul_4bit( pytest.skip("This configuration is not supported on HPU.") for i in range(3): - # normal multiply - if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype) - B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype) - target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype) - bias = None - bias2 = None + A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype) + bias = 
None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + + B2, quant_state = bnb.functional.quantize_4bit( + B, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + + if transpose_B: + out_torch = torch.matmul(A, B.t()) + out_bnb = bnb.matmul_4bit(A, B2, quant_state, bias=bias2) + else: + out_torch = torch.matmul(A, B) + with pytest.warns(DeprecationWarning): + out_bnb = bnb.matmul_4bit(A, B2, quant_state, bias=bias2) + + if has_bias: + out_torch = out_torch + bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + if device == "cuda": + torch.cuda.synchronize() + elif device == "hpu": + torch.hpu.synchronize() + + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None if has_bias: - bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2]) - bias2 = bias.clone() - torch.nn.init.xavier_uniform_(B) - - B2, quant_state = bnb.functional.quantize_4bit( - B, - compress_statistics=compress_statistics, - quant_type=quant_type, - ) - - if not transpose[0] and transpose[1]: - out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) - elif not transpose[0] and not transpose[1]: - out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B2, quant_state, bias=bias2) - + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None if has_bias: - out_torch += bias + gradBias2 = bias.grad + bias.grad = None - assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" - - n = out_bnb.numel() - err = torch.abs(out_bnb - out_torch).float().mean().item() - if n > 0: - assert err < 0.115 - - # assert err < 0.20 - if any(req_grad): - out_bnb.data.copy_(out_torch) - if device == "cuda": - torch.cuda.synchronize() - elif device == "hpu": - torch.hpu.synchronize() - - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() - loss_bnb.backward() - gradA1 = A.grad - gradB1 = B.grad - A.grad = None - B.grad = None - if has_bias: - gradBias1 = bias.grad - bias.grad = None - - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() - loss_torch.backward() - gradA2 = A.grad - gradB2 = B.grad - A.grad = None - B.grad = None - if has_bias: - gradBias2 = bias.grad - bias.grad = None + if req_grad[0]: + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - - if req_grad[2]: - torch.testing.assert_close(gradBias1, gradBias2) + if req_grad[2]: + torch.testing.assert_close(gradBias1, gradBias2) diff --git a/tests/test_functional.py b/tests/test_functional.py index a2127c58a..62585de0c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -20,6 +20,21 @@ torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) k = 20 +_GEMM_4BIT_SHAPES = [ + (1, 256, 128), # SIMT path on all GPUs + (40, 2048, 128), # 64x32-128 + sm75 tile + (9, 5120, 128), # 32x64-128 on L40S/4090 + (40, 
5120, 128), # 32x64-64 on A10/L40S/4090 + (80, 8192, 128), # 32x128-64 on L40S/4090 + (7, 14336, 128), # many tiles depending on GPU + (80, 14336, 128), # 128x64-64 on L40S/4090/H100 + (80, 21504, 128), # 128x128-64 across many GPUs + (4, 28672, 128), # sm75 32x128 + HBM 32x256-128 + (40, 36864, 128), # 64x128-64 wide-N + (48, 8192, 512), # 64x64-64 on GDDR+A100, K=512 + (48, 8192, 8192), # 64x64-128, large K +] + def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): idx = torch.isclose(a, b, rtol=rtol, atol=atol) @@ -798,8 +813,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): qB, state = F._convert_weight_packed_for_cpu(qB, state) qB = qB.t() C2 = F.gemv_4bit(A, qB.t(), state=state) - A.requires_grad = True - C1 = bnb.matmul_4bit(A, qB.t(), state) + # dequant+F.linear reference path + C1 = torch.nn.functional.linear(A, F.dequantize_4bit(qB, state).to(dtype)) err1 = (C1 - C2).abs().float() err2 = (C3 - C2).abs().float() @@ -932,6 +947,61 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype): # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) + @pytest.mark.filterwarnings("ignore:inner dimension:UserWarning") + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 4096], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) + @pytest.mark.parametrize("MNK", _GEMM_4BIT_SHAPES, ids=[f"M{m}N{n}K{k}" for m, n, k in _GEMM_4BIT_SHAPES]) + def test_matmul_4bit(self, MNK, dtype, blocksize, quant_type, compress_statistics, has_bias, device): + M, N, K = MNK + if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, torch.uint8): + pytest.skip("This configuration is not supported on HPU.") + if device == "cpu" and (blocksize == 4096 or K > 128 or M > 40 or N > 8192): + pytest.skip("narrowed on CPU") + + B = torch.randn(N, K, dtype=dtype, device=device) / (K**0.5) + A = torch.randn(1, M, K, dtype=dtype, device=device) + bias = torch.randn(N, dtype=dtype, device=device) if has_bias else None + ref = torch.nn.functional.linear(A, B, bias) + + qB, qs = F.quantize_4bit( + B, blocksize=blocksize, quant_type=quant_type, compress_statistics=compress_statistics + ) + out = bnb.matmul_4bit(A, qB, qs, bias=bias) + + mean_err = (ref.float() - out.float()).abs().mean().item() + if quant_type == "nf4" or blocksize <= 64: + threshold = 0.115 + elif blocksize <= 256: + threshold = 0.13 + else: + threshold = 0.16 + assert mean_err < threshold + + @pytest.mark.parametrize("device", get_available_devices()) + def test_matmul_4bit_weight_orientation(self, device): + N, K = 256, 128 + dtype = torch.float16 + A = torch.randn(1, 4, K, dtype=dtype, device=device) + B = torch.randn(N, K, dtype=dtype, device=device) / (K**0.5) + ref = torch.nn.functional.linear(A, B) + qB, qs = F.quantize_4bit(B, blocksize=64, quant_type="nf4") + + # B.t() and canonical B must give identical results. 
+ out_canonical = bnb.matmul_4bit(A, qB, qs) + out_transposed = bnb.matmul_4bit(A, qB.t(), qs) + torch.testing.assert_close(out_canonical, out_transposed) + + # [K, N] quant_state emits DeprecationWarning and still produces correct output. + B_kn = B.t().contiguous() + qB_kn, qs_kn = F.quantize_4bit(B_kn, blocksize=64, quant_type="nf4") + with pytest.warns(DeprecationWarning, match="Re-quantize from the weight"): + out_kn = bnb.matmul_4bit(A, qB_kn, qs_kn) + assert (ref.float() - out_kn.float()).abs().mean().item() < 0.115 + def test_normal_map_tree(): code = F.create_normal_map() diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index ee1433641..12ed0eb27 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -355,10 +355,13 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) +@pytest.mark.parametrize("batch_size", [1, 16], ids=id_formatter("batch_size")) @pytest.mark.skipif( torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10" ) -def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode): +def test_linear4bit_torch_compile( + device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode, batch_size +): if device == "hpu" and not is_supported_on_hpu(quant_type): pytest.skip("This configuration is not supported on HPU.") @@ -395,7 +398,6 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st pytest.xfail("precision diverges on macos cpu") dim = 256 - batch_size = 16 torch.compiler.reset() diff --git a/tests/test_ops.py b/tests/test_ops.py index 3d8461f0d..bd5217748 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,7 +4,7 @@ import torch import bitsandbytes -from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu +from tests.helpers import TRUE_FALSE, describe_dtype, get_available_devices, id_formatter, is_supported_on_hpu opcheck = torch.library.opcheck @@ -266,6 +266,80 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize)) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("requires_grad", TRUE_FALSE, ids=id_formatter("requires_grad")) + @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) + @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) + def test_gemm_4bit(self, device, dtype, quant_type, compress_statistics, has_bias, storage_dtype, requires_grad): + if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): + pytest.skip("This configuration is not supported on HPU.") + + N, K, blocksize = 64, 64, 64 + + A = torch.randn(2, 2, K, dtype=dtype, device=device, requires_grad=requires_grad) + B = torch.randn(N, K, dtype=dtype, device=device) + B_q, qs = 
bitsandbytes.functional.quantize_4bit( + B, + blocksize=blocksize, + quant_type=quant_type, + compress_statistics=compress_statistics, + quant_storage=storage_dtype, + ) + bias = torch.randn(N, dtype=dtype, device=device) if has_bias else None + + if compress_statistics: + out = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B_q, + list(B.shape), + qs.state2.absmax, + blocksize, + quant_type, + bias=bias, + absmax_8bit=qs.absmax, + absmax_code=qs.state2.code, + absmax_offset=qs.offset, + ) + else: + out = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B_q, + list(B.shape), + qs.absmax, + blocksize, + quant_type, + bias=bias, + ) + + assert out.shape == (2, 2, N) + assert out.dtype == dtype + assert out.device.type == A.device.type + assert out.isreal().all() + + # TODO: remove detach when register_autograd is added for gemm_4bit. + # opcheck requires no autograd; detach A to skip the registration check. + A_op = A.detach() + if compress_statistics: + opcheck( + torch.ops.bitsandbytes.gemm_4bit.default, + (A_op, B_q, list(B.shape), qs.state2.absmax, blocksize, quant_type), + kwargs={ + "bias": bias, + "absmax_8bit": qs.absmax, + "absmax_code": qs.state2.code, + "absmax_offset": qs.offset, + }, + ) + else: + opcheck( + torch.ops.bitsandbytes.gemm_4bit.default, + (A_op, B_q, list(B.shape), qs.absmax, blocksize, quant_type), + kwargs={"bias": bias}, + ) + class TestNonContiguousInputs: """Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly."""
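One detail from the sm80 dispatcher worth restating: KC=128 is only selected when K % 128 == 0, because the kernels run k_iters = K / K_CHUNK chunks and drop any remainder. A quick illustration (Python):

    K = 192           # divisible by 64 but not by 128
    print(K // 128)   # 1 -> with KC=128 only 128 of the 192 K-elements would be reduced
    print(K // 64)    # 3 -> KC=64 covers the full reduction, hence the remap to a KC=64 tile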