From cb28bf86c2a8060c1d4de216743fd75f91d87f82 Mon Sep 17 00:00:00 2001 From: kgrama <3001197+kgrama@users.noreply.github.com> Date: Sun, 10 May 2026 17:02:44 +0100 Subject: [PATCH] pdf initial commit --- bitsandbytes/_pbf4.py | 145 ++++++++++++++++++++++ bitsandbytes/_pbf8.py | 59 +++++++++ bitsandbytes/autograd/_functions.py | 10 +- bitsandbytes/backends/cpu/ops.py | 10 +- bitsandbytes/backends/cuda/ops.py | 16 ++- bitsandbytes/backends/default/ops.py | 41 ++++++- bitsandbytes/backends/triton/ops.py | 1 + bitsandbytes/functional.py | 21 +++- csrc/common.h | 1 + csrc/kernels.cu | 173 +++++++++++++++++++++++++++ csrc/ops.cu | 19 +++ csrc/pythonInterface.cpp | 68 +++++++++++ tests/test_pbf4.py | 163 +++++++++++++++++++++++++ 13 files changed, 713 insertions(+), 14 deletions(-) create mode 100644 bitsandbytes/_pbf4.py create mode 100644 bitsandbytes/_pbf8.py create mode 100644 tests/test_pbf4.py diff --git a/bitsandbytes/_pbf4.py b/bitsandbytes/_pbf4.py new file mode 100644 index 000000000..6b8ee92de --- /dev/null +++ b/bitsandbytes/_pbf4.py @@ -0,0 +1,145 @@ +"""PBF4 — peace-quant's PBF-MX, derived dynamically from the PBF8 spine. + +The 4-bit LUT is generated by sampling 8 magnitudes from the standard PBF8 +spine (``_pbf8``) at every-other level — same construction as peace-quant's +[mx_pbf_lut](crates/pbf8/src/mx4.rs#L25). Step ratio between levels is +``exp(2·LEVEL_LOG_STEP) = 2^(3/8) ≈ 1.297`` (~30%/level). Magnitudes are +normalised so the top entry is exactly 1.0, then mirrored into NF4-style +asymmetric layout (7 neg + 0 + 8 pos = 16 unique values). + +Same LUT for every tensor — no per-tensor calibration. The LUT lives in +``QuantState.code`` so CUDA's LUT-generic ``kgemm_4bit_inference_naive`` kernel +runs unchanged. +""" + +from typing import Optional + +import torch + +from . import _pbf8 + +# 4-bit NF4-layout: 7 neg + 0 + 8 pos = 16 unique LUT entries. +NUM_LUT_MAGS: int = 8 + +# Sample 8 magnitudes from the PBF8 spine at every-other level (stride 2), +# starting at level 2 so the highest entry comes from PBF8 mag-index 16 +# (= ``RING_POW[1]``). The first sampled magnitude is at level 2 (≈ 1.297 +# in raw PBF8 space); after normalising the top to 1.0 we get peace-quant's +# pbf_mx_lut shape: ``[0.162, 0.21, 0.273, 0.354, 0.458, 0.594, 0.771, 1.0]``. +_PBF4_PBF8_START_LEVEL: int = 2 + + +def _build_pbf_mx_lut() -> torch.Tensor: + """Generate the PBF4 LUT by sub-sampling the PBF8 spine.""" + raw_mags = _pbf8.sample_every_other_level(NUM_LUT_MAGS, start_level=_PBF4_PBF8_START_LEVEL) + top = raw_mags[-1] + mags = [m / top for m in raw_mags] + values = [-m for m in reversed(mags[1:])] + [0.0] + mags + return torch.tensor(values, dtype=torch.float32) + + +# PBF4 LUT — derived from PBF8 at module load. Same LUT for every tensor. +PBF_MX_LUT: torch.Tensor = _build_pbf_mx_lut() + + +# --------------------------------------------------------------------------- +# Block normalisation (per-block absmax → [-1, 1]). 
+# --------------------------------------------------------------------------- + + +def _block_normalise(A: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor, int]: + n = A.numel() + full_blocks = n // blocksize + rem = n % blocksize + blocks = full_blocks + (1 if rem else 0) + A_flat = A.reshape(n).float() + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + parts = [] + if full_blocks: + full = A_flat[: full_blocks * blocksize].reshape(full_blocks, blocksize) + absmax[:full_blocks] = full.abs().max(dim=-1).values + denom = absmax[:full_blocks].clamp(min=1e-30).view(-1, 1) + parts.append((full / denom).clamp(-1, 1).reshape(-1)) + if rem: + rem_part = A_flat[-rem:] + absmax[-1] = rem_part.abs().max() + denom = absmax[-1].clamp(min=1e-30) + parts.append((rem_part / denom).clamp(-1, 1)) + scaled = torch.cat(parts, dim=0) if parts else A_flat + return scaled, absmax, full_blocks + + +# --------------------------------------------------------------------------- +# Quantize / dequantize using the fixed PBF-MX LUT. +# --------------------------------------------------------------------------- + + +def quantize_pbf4_blockwise( + A: torch.Tensor, + blocksize: int, + quant_storage: torch.dtype, + lut: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize ``A`` to PBF4. Returns ``(packed_codes, absmax, lut)``. + + LUT defaults to ``PBF_MX_LUT`` (top = 1.0), so the bnb kernel formula + ``quant_map[byte] * absmax`` reproduces block-max directly. + """ + if lut is None: + lut = PBF_MX_LUT + lut = lut.to(A.device) + + scaled, absmax, _ = _block_normalise(A, blocksize) + diff = (scaled.unsqueeze(-1) - lut.to(scaled.dtype)).abs() + codes = diff.argmin(dim=-1).to(torch.int64) + + if codes.numel() % 2 != 0: + codes = torch.cat([codes, codes.new_zeros(1)]) + packed = ((codes[::2] << 4) | codes[1::2]).to(torch.uint8).unsqueeze(1) + + if quant_storage != torch.uint8: + packed = packed.squeeze().view(quant_storage).unsqueeze(1) + return packed, absmax, lut + + +def dequantize_pbf4_blockwise( + A: torch.Tensor, + absmax: torch.Tensor, + lut: torch.Tensor, + blocksize: int, + shape, + dtype: torch.dtype, +) -> torch.Tensor: + """Inverse of ``quantize_pbf4_blockwise``: ``lut[byte] * absmax``.""" + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + A = A.reshape(-1) + + out_codes = torch.empty(A.numel() * 2, dtype=torch.int64, device=A.device) + out_codes[::2] = (A >> 4).long() + out_codes[1::2] = (A & 0xF).long() + + n = 1 + for s in shape: + n *= s + out_codes = out_codes[:n] + + lut_dt = lut.to(dtype).to(A.device) + decoded = lut_dt[out_codes] + + blocks = (n + blocksize - 1) // blocksize + rem = n % blocksize + has_rem = rem > 0 + out = torch.empty((n,), dtype=dtype, device=A.device) + if has_rem: + full_n = (blocks - 1) * blocksize + if full_n: + out[:full_n] = (decoded[:full_n].view(-1, blocksize) * absmax[: blocks - 1].to(dtype).view(-1, 1)).reshape( + -1 + ) + out[full_n:] = decoded[full_n:] * absmax[-1].to(dtype) + else: + out = (decoded.view(blocks, blocksize) * absmax.to(dtype).view(-1, 1)).reshape(-1) + + return out.reshape(shape) diff --git a/bitsandbytes/_pbf8.py b/bitsandbytes/_pbf8.py new file mode 100644 index 000000000..df7558b7e --- /dev/null +++ b/bitsandbytes/_pbf8.py @@ -0,0 +1,59 @@ +"""PBF8 — peace-quant's standard 8-bit log-polar spine. + +Mirrors the constants from +[crates/pbf8/src/format.rs](crates/pbf8/src/format.rs): + +- ``BASE = φ + π`` (≈ 4.7596) — the irrational base anchoring all rings. 
+- ``RING_POW`` — 8-element ring spine, each ring 8x the previous (R=8 spacing).
+  Spans roughly ``5.8e-4 .. 1100``.
+- ``LEVEL_LOG_STEP = ln(8) / 16`` — log-step per level inside a ring.
+- 8 rings * 16 levels = 128 magnitudes per sign side. With the byte-0/byte-255
+  sentinels, total 256 codes per byte.
+
+For a magnitude index ``mag ∈ [0, 127]``:
+``decode_mag(mag) = (BASE/8192) · exp(mag · LEVEL_LOG_STEP) = RING_POW[mag>>4] · exp((mag & 0xF) · LEVEL_LOG_STEP)``.
+
+This module exposes the spine to higher-level formats. PBF4 (``_pbf4``) builds
+its 4-bit LUT by sampling 8 magnitudes from this spine at every-other level.
+"""
+
+import math
+
+PHI: float = 1.618_034
+BASE: float = PHI + math.pi  # ≈ 4.7595918
+
+# Ring spine — 8 rings at R=8 spacing, span ≈ 5.8e-4 .. 1100.
+RING_POW: tuple[float, ...] = (
+    BASE / 8192.0,
+    BASE / 1024.0,
+    BASE / 128.0,
+    BASE / 16.0,
+    BASE / 2.0,
+    4.0 * BASE,
+    32.0 * BASE,
+    256.0 * BASE,
+)
+
+LEVEL_LOG_STEP: float = math.log(8.0) / 16.0  # = 3·ln(2)/16
+
+# 128 magnitudes (mag indices 0..127). Byte 0 is the zero sentinel, byte 255
+# is the saturation sentinel; nonzero magnitudes occupy bytes 1..254.
+N_MAGS: int = 128
+
+
+def decode_mag(mag: int) -> float:
+    """Decode a magnitude index ``mag ∈ [0, N_MAGS)`` to its positive fp32 value."""
+    if mag <= 0:
+        return RING_POW[0]
+    if mag >= N_MAGS:
+        mag = N_MAGS - 1
+    return (BASE / 8192.0) * math.exp(mag * LEVEL_LOG_STEP)
+
+
+def sample_every_other_level(n: int = 8, start_level: int = 0) -> list[float]:
+    """Return ``n`` magnitudes by sampling the PBF8 spine at every-other level.
+
+    Used to derive lower-bit-depth LUTs (PBF4 takes ``n=8``). Sampling stride
+    is 2 levels = ``2 · LEVEL_LOG_STEP`` (~30% step ratio).
+    """
+    return [decode_mag(start_level + 2 * k) for k in range(n)]
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 95a7d9090..4def1bbb1 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -382,8 +382,16 @@ def matmul_4bit(
     bias: Optional[torch.Tensor] = None,
 ):
     assert quant_state is not None
+    # PBF4 stores its fixed 16-entry LUT in quant_state.code; CUDA's
+    # kgemm_4bit_inference_naive loads ``datatype`` into shared memory and is fully
+    # LUT-generic, so pbf4 piggybacks the existing gemv_4bit fast path with no kernel
+    # changes. The CPU fused-gemm path (``packing_format_for_cpu``) is opt-in and
+    # uses a code-tensor heuristic that doesn't account for non-standard LUTs,
+    # so we keep pbf4 off that specific lane.
+ is_pbf4 = getattr(quant_state, "quant_type", None) == "pbf4" + if A.device.type == "cpu": - if getattr(quant_state, "packing_format_for_cpu", False): + if getattr(quant_state, "packing_format_for_cpu", False) and not is_pbf4: out = F.gemv_4bit(A, B, out, state=quant_state) if bias is not None: out += bias diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 6b82c2421..33af639d2 100755 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -147,7 +147,10 @@ def _( dtype: torch.dtype, ) -> torch.Tensor: torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + quant_type in ("nf4", "fp4", "pbf4"), + lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}", + ) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", @@ -160,7 +163,10 @@ def _( # Odd shape is not supported by this kernel; fallback to generic implementation shape_fallback = shape[-1] % 2 != 0 - if avx512_fallback or shape_fallback: + # No native pbf4 CPU kernel — route to the LUT-driven default impl. + pbf4_fallback = quant_type == "pbf4" + + if avx512_fallback or shape_fallback or pbf4_fallback: from ..default.ops import _dequantize_4bit_impl return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 409e0252d..074cc0afb 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -313,7 +313,7 @@ def _( A = A.contiguous() torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - torch._check(quant_type in ["fp4", "nf4"]) + torch._check(quant_type in ["fp4", "nf4", "pbf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", @@ -337,16 +337,22 @@ def _( if A.dtype == torch.bfloat16: if quant_type == "fp4": lib.cquantize_blockwise_bf16_fp4(*args) + elif quant_type == "pbf4": + lib.cquantize_blockwise_bf16_pbf4(*args) else: lib.cquantize_blockwise_bf16_nf4(*args) elif A.dtype == torch.float16: if quant_type == "fp4": lib.cquantize_blockwise_fp16_fp4(*args) + elif quant_type == "pbf4": + lib.cquantize_blockwise_fp16_pbf4(*args) else: lib.cquantize_blockwise_fp16_nf4(*args) elif A.dtype == torch.float32: if quant_type == "fp4": lib.cquantize_blockwise_fp32_fp4(*args) + elif quant_type == "pbf4": + lib.cquantize_blockwise_fp32_pbf4(*args) else: lib.cquantize_blockwise_fp32_nf4(*args) @@ -393,7 +399,7 @@ def _dequantize_4bit_impl( A = A.contiguous() torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - torch._check(quant_type in ["fp4", "nf4"]) + torch._check(quant_type in ["fp4", "nf4", "pbf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", @@ -413,16 +419,22 @@ def _dequantize_4bit_impl( if out.dtype == torch.bfloat16: if quant_type == "fp4": lib.cdequantize_blockwise_bf16_fp4(*args) + elif quant_type == "pbf4": + lib.cdequantize_blockwise_bf16_pbf4(*args) else: lib.cdequantize_blockwise_bf16_nf4(*args) elif out.dtype == torch.float16: if quant_type == "fp4": lib.cdequantize_blockwise_fp16_fp4(*args) + elif quant_type == 
"pbf4": + lib.cdequantize_blockwise_fp16_pbf4(*args) else: lib.cdequantize_blockwise_fp16_nf4(*args) elif out.dtype == torch.float32: if quant_type == "fp4": lib.cdequantize_blockwise_fp32_fp4(*args) + elif quant_type == "pbf4": + lib.cdequantize_blockwise_fp32_pbf4(*args) else: lib.cdequantize_blockwise_fp32_nf4(*args) diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 6f5eecdf2..769b3ab83 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -216,17 +216,28 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, return out -@register_kernel("bitsandbytes::quantize_4bit", "default") -def _( +def _quantize_4bit_impl( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + quant_type in ("nf4", "fp4", "pbf4"), + lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}", + ) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", ) + if quant_type == "pbf4": + # PBF4 uses a fixed LUT derived from the PBF8 spine — same shape every + # call, so the op-level path can produce correct (packed, absmax) + # without any extra metadata. Python fallback for non-CUDA backends. + from ..._pbf4 import PBF_MX_LUT, quantize_pbf4_blockwise + + packed, absmax, _ = quantize_pbf4_blockwise(A, blocksize, quant_storage, lut=PBF_MX_LUT) + return packed, absmax + n = A.numel() full_blocks = n // blocksize rem = n % blocksize @@ -262,6 +273,13 @@ def _( return packed, absmax.float() +@register_kernel("bitsandbytes::quantize_4bit", "default") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage) + + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -270,6 +288,11 @@ def _dequantize_4bit_impl( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: + if quant_type == "pbf4": + from ..._pbf4 import PBF_MX_LUT, dequantize_pbf4_blockwise + + return dequantize_pbf4_blockwise(A, absmax, PBF_MX_LUT, blocksize, shape, dtype) + # Enable non uint8 dtype if A.dtype != torch.uint8: A = A.view(torch.uint8) @@ -318,7 +341,10 @@ def _( dtype: torch.dtype, ) -> torch.Tensor: torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + quant_type in ("nf4", "fp4", "pbf4"), + lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}", + ) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", @@ -336,8 +362,11 @@ def _( code: torch.Tensor, blocksize: int, ) -> torch.Tensor: - # Applied from dequantize_4bit - quant_type = "fp4" if code[1] > 0 else "nf4" + # Recover quant_type from the LUT's distinctive ``code[1]``: + # fp4 ≈ +0.0052 (small positive); nf4 ≈ -0.696. (PBF4 has per-tensor + # calibrated LUTs that can't reach this op-level path — see + # ``functional.dequantize_4bit`` which short-circuits PBF4 directly.) 
+ quant_type = "fp4" if float(code[1]) > 0 else "nf4" B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype) return torch.nn.functional.linear( diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 6b1a2904b..18db5b3f3 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -115,6 +115,7 @@ def dequantize_4bit( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) + # torch._check( # A.dtype == torch.uint8, # lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}", diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0165a1288..7825c3166 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -419,7 +419,7 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: class QuantState: """container for quantization state components to work with Params4bit and similar classes""" - valid_quant_types = ("fp4", "nf4") + valid_quant_types = ("fp4", "nf4", "pbf4") valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] valid_qs_keys = [ "absmax", @@ -809,6 +809,13 @@ def get_4bit_type(typename, device=None, blocksize=64): # # All values are normalized to [-1, 1] after construction (see end of function). data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] + elif typename == "pbf4": + # PBF4 (peace-quant PBF-MX): 8 magnitudes sampled every-other level of the + # standard PBF8 spine. The fixed LUT lives in ``bitsandbytes._pbf4.PBF_MX_LUT``; + # it is dynamically derived from the PBF8 constants in ``_pbf8`` at module load. + from bitsandbytes._pbf4 import PBF_MX_LUT + + data = PBF_MX_LUT.tolist() elif typename == "int4": data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] elif typename == "af4": @@ -913,8 +920,16 @@ def quantize_4bit( quant_type, quant_storage, ) - - code = get_4bit_type(quant_type, device=A.device) + if quant_type == "pbf4": + # PBF4 uses a fixed LUT derived from the PBF8 spine (`_pbf4.PBF_MX_LUT`). + # The op only returns (packed, absmax); attach the canonical LUT to + # QuantState.code here so CUDA's LUT-generic gemv_4bit and the Triton/ + # default fallback paths all see the same table. + from bitsandbytes._pbf4 import PBF_MX_LUT + + code = PBF_MX_LUT.to(A.device) + else: + code = get_4bit_type(quant_type, device=A.device) if compress_statistics: offset = _absmax.mean() diff --git a/csrc/common.h b/csrc/common.h index 1496c0bc3..3f528172a 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -4,4 +4,5 @@ typedef enum DataType_t { General8bit = 0, FP4 = 1, NF4 = 2, + PBF4 = 3, } DataType_t; diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0d313c8d7..e1ab40ad9 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -42,6 +42,59 @@ __device__ static float nf4_dequantization_lut[16] = { 1.0f // 0b1111 }; +// PBF4 (peace-quant PBF-MX) -- 8 magnitudes sampled at every-other level of the +// PBF8 standard ring, normalised so the top entry is 1.0, then mirrored NF4-style +// (7 negatives + 0 + 8 positives = 16 unique entries). 
+// +// Construction (mirrors bitsandbytes/_pbf4.py::_build_pbf_mx_lut, which derives +// these from the PBF8 spine constants in bitsandbytes/_pbf8.py): +// +// raw_mags[k] = (BASE/8192) * exp((2*k + START_LEVEL) * LEVEL_LOG_STEP) +// where BASE = phi + pi (peace-quant irrational base) +// LEVEL_LOG_STEP = ln(8) / 16 +// START_LEVEL = 2 (every-other level, base-anchored) +// k = 0 .. 7 (8 magnitudes) +// +// Equivalently, mag[k] / mag[7] = 2^(3*(k - 7)/8), i.e. each step is a factor +// of 2^(3/8) ~= 1.297 (~30%/level). After dividing by mag[7] the entries are: +// +// k=0: 2^(-21/8) = 0.16210494... +// k=1: 2^(-18/8) = 0.21022410... +// k=2: 2^(-15/8) = 0.27262693... +// k=3: 2^(-12/8) = 0.35355339... +// k=4: 2^(-9/8) = 0.45850202... +// k=5: 2^(-6/8) = 0.59460356... +// k=6: 2^(-3/8) = 0.77110541... +// k=7: 2^( 0) = 1.0 +// +// Layout: byte = sign_bit << 3 | mag_index, mirroring NF4. Byte 0b0111 (= 7) is +// the unique zero; byte 0b1000 (= 8) is the smallest positive magnitude (NOT a +// duplicate zero). Sign bit is bit 3 (LSB of the upper nibble). +// +// To regenerate: change PBF8's BASE, LEVEL_LOG_STEP, or START_LEVEL in +// bitsandbytes/_pbf8.py / bitsandbytes/_pbf4.py and copy the resulting Python +// PBF_MX_LUT.tolist() into this table; the values must stay in lock-step +// because the device-side LUT cannot be loaded from the Python module at +// kernel-launch time. +__device__ static float pbf4_dequantization_lut[16] = { + -1.0f, // 0b0000 -2^( 0) + -0.7711054f, // 0b0001 -2^(-3/8) + -0.5946036f, // 0b0010 -2^(-6/8) + -0.4585020f, // 0b0011 -2^(-9/8) + -0.3535534f, // 0b0100 -2^(-12/8) + -0.2726269f, // 0b0101 -2^(-15/8) + -0.2102241f, // 0b0110 -2^(-18/8) + 0.0f, // 0b0111 zero + 0.1621049f, // 0b1000 2^(-21/8) (smallest positive) + 0.2102241f, // 0b1001 2^(-18/8) + 0.2726269f, // 0b1010 2^(-15/8) + 0.3535534f, // 0b1011 2^(-12/8) + 0.4585020f, // 0b1100 2^(-9/8) + 0.5946036f, // 0b1101 2^(-6/8) + 0.7711054f, // 0b1110 2^(-3/8) + 1.0f // 0b1111 2^( 0) (largest positive) +}; + // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda // HIP has native atomicMax for float; CUDA needs a CAS loop #if !BNB_HIP @@ -107,6 +160,69 @@ __device__ unsigned char dQuantizeFP4(float x) { __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } +__device__ __forceinline__ float dDequantizePBF4(unsigned char val) { return pbf4_dequantization_lut[val & 0x0F]; } + +__device__ unsigned char dQuantizePBF4(float x) { + // Encode an fp32 in [-1.0, 1.0] to a 4-bit PBF4 code. Uses a balanced + // binary-search tree (3 comparisons in the worst case) over the PBF4 + // positive LUT entries; the negative side mirrors via the sign bit + // (bit 3, set by `sign` below). Because PBF4 is log-spaced the optimal + // split points are the GEOMETRIC midpoints between adjacent LUT entries, + // NOT arithmetic midpoints. The split between 0 and the smallest positive + // entry is a special case -- there is no geometric mean of 0 and a + // nonzero value, so we use the linear midpoint there (= L0 / 2). 
+    //
+    // Threshold derivation -- L_k are the positive LUT magnitudes from
+    // pbf4_dequantization_lut[8 + k] for k = 0..7
+    // (L = [0.1621, 0.2102, 0.2726, 0.3536, 0.4585, 0.5946, 0.7711, 1.0]):
+    //
+    //   T0 = L0 / 2                = 0.08105   (zero -> mag 0 split)
+    //   T_k = sqrt(L_{k-1} * L_k)    for k = 1..7
+    //   T1 = sqrt(0.1621 * 0.2102) = 0.18468
+    //   T2 = sqrt(0.2102 * 0.2726) = 0.23947
+    //   T3 = sqrt(0.2726 * 0.3536) = 0.31052
+    //   T4 = sqrt(0.3536 * 0.4585) = 0.40262
+    //   T5 = sqrt(0.4585 * 0.5946) = 0.52215
+    //   T6 = sqrt(0.5946 * 0.7711) = 0.67710
+    //   T7 = sqrt(0.7711 * 1.0000) = 0.87813
+    //
+    // To regenerate when the LUT changes:
+    //
+    //   >>> import math
+    //   >>> L = [v for v in PBF_MX_LUT.tolist() if v > 0]
+    //   >>> [L[0] / 2.0] + [math.sqrt(L[k-1] * L[k]) for k in range(1, 8)]
+    int sign = x < 0 ? 0b1000 : 0b0000;
+    x = fabsf(x);
+    // The LUT above is monotone (NF4-style): positive magnitude k encodes as
+    // (8 + k) and its negative mirror as (7 - k), so every branch below returns
+    // the matching positive/negative code pair.
+    // Upper half: |x| > T4 -> mags 4..7 (LUT entries 0.4585 .. 1.0)
+    if (x > 0.40262f) {
+        if (x > 0.67710f) {
+            if (x > 0.87813f)
+                return sign ? 0b0000 : 0b1111; // mag 7, |L| = 1.0
+            else
+                return sign ? 0b0001 : 0b1110; // mag 6, |L| = 0.7711054
+        } else if (x > 0.52215f)
+            return sign ? 0b0010 : 0b1101; // mag 5, |L| = 0.5946036
+        else
+            return sign ? 0b0011 : 0b1100; // mag 4, |L| = 0.4585020
+    }
+    // Lower half: |x| in (T1, T4] -> mags 1..3
+    else if (x > 0.18468f) {
+        if (x > 0.31052f)
+            return sign ? 0b0100 : 0b1011; // mag 3, |L| = 0.3535534
+        else if (x > 0.23947f)
+            return sign ? 0b0101 : 0b1010; // mag 2, |L| = 0.2726269
+        else
+            return sign ? 0b0110 : 0b1001; // mag 1, |L| = 0.2102241
+    }
+    // |x| in (T0, T1] -> mag 0 (smallest positive, |L| = 0.1621049). The
+    // asymmetric layout has no -0.1621049 entry, so negative inputs in this
+    // band fall through to the unique zero code below.
+    else if (x > 0.08105f && !sign)
+        return 0b1000;
+    // |x| <= T0 -> exact zero. The unique zero is byte 0b0111; the negative
+    // side does NOT have a duplicate -0 (NF4-style asymmetric layout).
+    else
+        return 0b0111;
+}
+
 __device__ unsigned char dQuantizeNF4(float x) {
 
     // the values for this tree was generated by test_normal_map_tree
@@ -365,6 +481,13 @@ __global__ void kQuantizeBlockwise(
             qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
         }
         break;
+    case PBF4:
+#pragma unroll NUM_PER_TH
+        for (int j = 0; j < NUM_PER_TH / 2; j++) {
+            qvals[j] = dQuantizePBF4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizePBF4(((float)vals[2 * j + 1]) * local_abs_max);
+        }
+        break;
     }
 
     __syncthreads();
@@ -456,6 +579,13 @@ __global__ void kQuantizeBlockwiseSmall(
             qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
         }
         break;
+    case PBF4:
+#pragma unroll NUM_PER_TH
+        for (int j = 0; j < NUM_PER_TH / 2; j++) {
+            qvals[j] = dQuantizePBF4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizePBF4(((float)vals[2 * j + 1]) * local_abs_max);
+        }
+        break;
     }
 
     __syncthreads();
@@ -521,6 +651,13 @@ __global__ void
             vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
         }
         break;
+    case PBF4:
+#pragma unroll NUM_PER_TH
+        for (int j = 0; j < NUM_PER_TH; j++) {
+            vals[j * 2] = dDequantizePBF4(qvals[j] >> 4) * local_abs_max;
+            vals[j * 2 + 1] = dDequantizePBF4(qvals[j] & 0x0F) * local_abs_max;
+        }
+        break;
     }
 
     __syncthreads();
@@ -1722,6 +1859,13 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 4096, 4, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 2048, 4, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 1024, 4, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 512, 2, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 256, 2, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 128, 2, 0, PBF4)
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, PBF4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) @@ -1744,6 +1888,13 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, PBF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 1, General8bit) @@ -1767,6 +1918,13 @@ MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, PBF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, PBF4) // Template instantiations for kQuantizeBlockwiseSmall (4-bit only) #define MAKE_kQuantizeBlockwiseSmall(dtype, qblock_size, data_type_name) \ @@ -1782,6 +1940,9 @@ MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, FP4) MAKE_kQuantizeBlockwiseSmall(half, 32, NF4) MAKE_kQuantizeBlockwiseSmall(float, 32, NF4) MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, NF4) +MAKE_kQuantizeBlockwiseSmall(half, 32, PBF4) +MAKE_kQuantizeBlockwiseSmall(float, 32, PBF4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, PBF4) // QBLOCK_SIZE=64 instantiations (blocksize=64, 4-bit) MAKE_kQuantizeBlockwiseSmall(half, 64, FP4) @@ -1790,6 +1951,9 @@ MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, FP4) MAKE_kQuantizeBlockwiseSmall(half, 64, NF4) MAKE_kQuantizeBlockwiseSmall(float, 64, NF4) MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, NF4) +MAKE_kQuantizeBlockwiseSmall(half, 64, PBF4) +MAKE_kQuantizeBlockwiseSmall(float, 64, PBF4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, PBF4) template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n @@ -1800,6 +1964,9 @@ template __global__ void kDequantizeBlockwise( template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); @@ -1809,6 +1976,9 @@ template __global__ void kDequantizeBlockwise( template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); template __global__ void kDequantizeBlockwise( float* code, unsigned char* 
A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n ); @@ -1818,6 +1988,9 @@ template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n ); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ diff --git a/csrc/ops.cu b/csrc/ops.cu index 16eed4e81..112b9bbf2 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -506,6 +506,9 @@ template void quantizeBlockwise( template void quantizeBlockwise( float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); template void quantizeBlockwise( float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); @@ -518,6 +521,9 @@ template void quantizeBlockwise( template void quantizeBlockwise( float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); template void quantizeBlockwise( float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n @@ -534,6 +540,10 @@ template void quantizeBlockwise( float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream @@ -544,6 +554,9 @@ template void dequantizeBlockwise( template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream ); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream +); template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream ); @@ -553,6 +566,9 @@ template void dequantizeBlockwise( template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream ); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream +); template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream ); @@ -562,6 +578,9 @@ template void dequantizeBlockwise( template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream ); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream +); #define MAKE_optimizer32bit(name, gtype) 
\ template void optimizer32bit( \ diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 214c2a2d8..24a165e39 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -142,6 +142,10 @@ void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned ch quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); } +void quantizeBlockwise_fp16_pbf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + void quantizeBlockwise_bf16( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { @@ -160,6 +164,12 @@ void quantizeBlockwise_bf16_nf4( quantizeBlockwise<__nv_bfloat16, 0, NF4>(nullptr, A, absmax, out, nullptr, 0, blocksize, n); } +void quantizeBlockwise_bf16_pbf4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise<__nv_bfloat16, 0, PBF4>(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise(code, A, absmax, out, nullptr, 0, blocksize, n); } @@ -172,6 +182,10 @@ void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned c quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); } +void quantizeBlockwise_fp32_pbf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + void dequantizeBlockwise_fp16( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream ) { @@ -190,6 +204,12 @@ void dequantizeBlockwise_fp16_nf4( dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp16_pbf4( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + void dequantizeBlockwise_fp32( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { @@ -208,6 +228,12 @@ void dequantizeBlockwise_fp32_nf4( dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp32_pbf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + void dequantizeBlockwise_bf16( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { @@ -226,6 +252,12 @@ void dequantizeBlockwise_bf16_nf4( dequantizeBlockwise<__nv_bfloat16, NF4>(nullptr, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_pbf4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise<__nv_bfloat16, PBF4>(nullptr, A, absmax, out, blocksize, n, stream); +} + int igemmlt_32( cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream @@ -367,6 +399,12 @@ void cdequantize_blockwise_fp16_nf4( dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } +void cdequantize_blockwise_fp16_pbf4( + float* code, unsigned char* A, float* 
absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp16_pbf4(code, A, absmax, out, blocksize, n, stream); +} + void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } @@ -379,6 +417,12 @@ void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } +void cquantize_blockwise_fp16_pbf4( + float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_fp16_pbf4(code, A, absmax, out, blocksize, n); +} + void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } @@ -395,6 +439,12 @@ void cquantize_blockwise_fp32_nf4( quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } +void cquantize_blockwise_fp32_pbf4( + float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_fp32_pbf4(code, A, absmax, out, blocksize, n); +} + void cdequantize_blockwise_fp32( float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream ) { @@ -413,6 +463,12 @@ void cdequantize_blockwise_fp32_nf4( dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } +void cdequantize_blockwise_fp32_pbf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp32_pbf4(code, A, absmax, out, blocksize, n, stream); +} + void cquantize_blockwise_bf16( float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n ) { @@ -431,6 +487,12 @@ void cquantize_blockwise_bf16_nf4( quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } +void cquantize_blockwise_bf16_pbf4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_bf16_pbf4(code, A, absmax, out, blocksize, n); +} + void cdequantize_blockwise_bf16( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream ) { @@ -449,6 +511,12 @@ void cdequantize_blockwise_bf16_nf4( dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } +void cdequantize_blockwise_bf16_pbf4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_bf16_pbf4(code, A, absmax, out, blocksize, n, stream); +} + #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits( \ gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ diff --git a/tests/test_pbf4.py b/tests/test_pbf4.py new file mode 100644 index 000000000..54c610eee --- /dev/null +++ b/tests/test_pbf4.py @@ -0,0 +1,163 @@ +"""Tests for the PBF4 (peace-quant PBF-MX) implementation. + +PBF4 here is a fixed 4-bit LUT derived dynamically from the PBF8 spine +(``_pbf8``) at every-other level — same construction as peace-quant's +``mx_pbf_lut``. No per-tensor calibration. The same LUT lives in +``QuantState.code`` for every tensor; CUDA's ``kgemm_4bit_inference_naive`` +reads it as the ``datatype`` arg into ``__shared__ T quant_map[16]``. 
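+
+Illustrative usage (the same round trip the tests below exercise; the values
+shown are exactly what those tests assert)::
+
+    >>> import torch
+    >>> import bitsandbytes.functional as F
+    >>> w = torch.randn(128, 128)
+    >>> packed, qs = F.quantize_4bit(w, blocksize=64, quant_type="pbf4")
+    >>> qs.quant_type, qs.code.numel()
+    ('pbf4', 16)
+    >>> F.dequantize_4bit(packed, qs, blocksize=64, quant_type="pbf4").shape
+    torch.Size([128, 128])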
+""" + +import math + +import pytest +import torch + +from bitsandbytes import _pbf8 +from bitsandbytes._pbf4 import ( + NUM_LUT_MAGS, + PBF_MX_LUT, +) +from bitsandbytes.backends.utils import CODE +import bitsandbytes.functional as F +from bitsandbytes.nn import Linear4bit + + +def test_no_default_lut_in_code_dict(): + assert "pbf4" not in CODE + + +def test_pbf_mx_lut_layout(): + listed = PBF_MX_LUT.tolist() + assert len(listed) == 16 + assert listed == sorted(listed) + assert listed[0] == pytest.approx(-1.0) + assert listed[-1] == pytest.approx(1.0) + assert sum(1 for v in listed if v == 0.0) == 1 + # NF4-style asymmetric: 7 negative + 0 + 8 positive. + assert sum(1 for v in listed if v < 0) == 7 + assert sum(1 for v in listed if v > 0) == 8 + + +def test_pbf_mx_lut_log_step(): + # Step ratio between adjacent positive entries should be exp(2·LEVEL_LOG_STEP) + # = 2^(3/8) ≈ 1.297 — the every-other-level sampling of PBF8. + pos = [v for v in PBF_MX_LUT.tolist() if v > 0] + expected_ratio = math.exp(2.0 * _pbf8.LEVEL_LOG_STEP) + for k in range(len(pos) - 1): + ratio = pos[k + 1] / pos[k] + assert ratio == pytest.approx(expected_ratio, rel=1e-5), ( + f"pos[{k + 1}]/pos[{k}] = {ratio}, expected {expected_ratio}" + ) + + +def test_pbf_mx_lut_derived_from_pbf8(): + raw = _pbf8.sample_every_other_level(NUM_LUT_MAGS, start_level=2) + top = raw[-1] + expected_normalised = [m / top for m in raw] + # Nonzero positives in the LUT match the 8 normalised mags from PBF8. + pos_nonzero = [v for v in PBF_MX_LUT.tolist() if v > 0] + assert len(pos_nonzero) == NUM_LUT_MAGS + for actual, expected in zip(pos_nonzero, expected_normalised): + assert actual == pytest.approx(expected, rel=1e-6) + + +@pytest.mark.parametrize("blocksize", [32, 64, 128, 256]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_pbf4_roundtrip_cpu(blocksize, dtype): + torch.manual_seed(0) + A = torch.randn(256, 256, device="cpu", dtype=dtype) + + qa, state = F.quantize_4bit(A, blocksize=blocksize, quant_type="pbf4") + A_dq = F.dequantize_4bit(qa, state, blocksize=blocksize, quant_type="pbf4") + + assert state.quant_type == "pbf4" + assert state.code.numel() == 16 + assert qa.dtype == torch.uint8 + assert A_dq.shape == A.shape + assert A_dq.dtype == dtype + assert torch.isfinite(A_dq).all() + + +def test_pbf4_state_code_is_the_fixed_lut(): + torch.manual_seed(0) + A = torch.randn(128, 128) + _, state = F.quantize_4bit(A, blocksize=64, quant_type="pbf4") + torch.testing.assert_close(state.code, PBF_MX_LUT.to(state.code.device), rtol=0, atol=0) + + # All tensors get the same LUT — this is a fixed format, not calibrated. 
+ B = torch.randn(64, 64) * 100 + _, state2 = F.quantize_4bit(B, blocksize=64, quant_type="pbf4") + torch.testing.assert_close(state2.code, state.code, rtol=0, atol=0) + + +def test_pbf4_quant_state_serialization_cpu(): + torch.manual_seed(0) + A = torch.randn(128, 128, device="cpu", dtype=torch.float32) + + qa, state = F.quantize_4bit(A, blocksize=64, quant_type="pbf4") + packed = state.as_dict(packed=True) + assert any("bitsandbytes__pbf4" in k for k in packed.keys()) + + restored = F.QuantState.from_dict(state.as_dict(), device=torch.device("cpu")) + assert restored.quant_type == "pbf4" + torch.testing.assert_close(restored.code, state.code, rtol=0, atol=0) + torch.testing.assert_close(restored.absmax, state.absmax, rtol=0, atol=0) + + A_dq_orig = F.dequantize_4bit(qa, state, blocksize=64, quant_type="pbf4") + A_dq_restored = F.dequantize_4bit(qa, restored, blocksize=64, quant_type="pbf4") + torch.testing.assert_close(A_dq_orig, A_dq_restored) + + +def test_pbf4_compress_statistics_compatible(): + torch.manual_seed(0) + A = torch.randn(512, 512) + qa, state = F.quantize_4bit(A, blocksize=64, quant_type="pbf4", compress_statistics=True) + assert state.nested + A_dq = F.dequantize_4bit(qa, state, blocksize=64, quant_type="pbf4") + assert torch.isfinite(A_dq).all() + + +def test_pbf4_op_direct_roundtrip(): + # The op layer handles pbf4 directly via the fixed PBF_MX_LUT — + # no special functional-only path required. + torch.manual_seed(0) + A = torch.randn(64, 64) + qa, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, 64, "pbf4", torch.uint8) + out = torch.ops.bitsandbytes.dequantize_4bit.default(qa, absmax, 64, "pbf4", tuple(A.shape), A.dtype) + assert torch.isfinite(out).all() + err = (A - out).abs().mean().item() + assert err < 0.15 + + +def test_pbf4_finite_on_diverse_distributions(): + torch.manual_seed(0) + distros = { + "normal": torch.randn(4096), + "uniform": torch.rand(4096) * 2 - 1, + "log-uniform": torch.exp(torch.empty(4096).uniform_(-6, 0)) * torch.sign(torch.randn(4096)), + "cauchy": torch.distributions.Cauchy(0.0, 0.3).sample((4096,)).clamp_(-50, 50), + "student-t": torch.distributions.StudentT(3.0).sample((4096,)), + "mix-outliers": torch.cat([torch.randn(3900), torch.randn(196) * 50.0]), + } + for name, A in distros.items(): + if A.dim() == 1: + A = A.unsqueeze(0) + qa, st = F.quantize_4bit(A, blocksize=64, quant_type="pbf4") + dq = F.dequantize_4bit(qa, st, blocksize=64, quant_type="pbf4") + assert torch.isfinite(dq).all(), f"{name}: dequantised output has NaN/Inf" + assert (dq.abs() <= A.abs().max() * 1.05).all(), f"{name}: dequantised exceeds 1.05x tensor max" + + +def test_linear4bit_pbf4_forward_cpu(): + torch.manual_seed(0) + layer = Linear4bit(64, 32, bias=True, quant_type="pbf4", compress_statistics=False) + layer = layer.to("cpu") # triggers quantization + + state = layer.weight.quant_state + assert state.quant_type == "pbf4" + assert state.code.numel() == 16 + + x = torch.randn(8, 64, dtype=torch.float32) + y = layer(x) + assert y.shape == (8, 32) + assert torch.isfinite(y).all()
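
As a cross-check of the LUT construction documented in bitsandbytes/_pbf8.py, bitsandbytes/_pbf4.py, and the csrc/kernels.cu comment, the standalone sketch below re-derives the 16-entry table and the encoder split points from the constants quoted in this patch (BASE = PHI + pi, LEVEL_LOG_STEP = ln(8)/16, start level 2). It is an illustration, not part of the change; START_LEVEL here is just the value _PBF4_PBF8_START_LEVEL takes above, and only the Python standard library is used.

    import math

    # Constants as quoted in bitsandbytes/_pbf8.py (this patch).
    PHI = 1.618_034
    BASE = PHI + math.pi
    LEVEL_LOG_STEP = math.log(8.0) / 16.0
    START_LEVEL = 2  # _PBF4_PBF8_START_LEVEL in _pbf4.py

    # 8 magnitudes sampled at every-other PBF8 level, normalised so the top is 1.0.
    raw = [(BASE / 8192.0) * math.exp((START_LEVEL + 2 * k) * LEVEL_LOG_STEP) for k in range(8)]
    mags = [m / raw[-1] for m in raw]

    # NF4-style asymmetric mirror: 7 negatives + unique zero + 8 positives.
    lut = [-m for m in reversed(mags[1:])] + [0.0] + mags
    assert len(lut) == 16 and lut == sorted(lut)

    # Adjacent positive entries differ by exp(2 * LEVEL_LOG_STEP) = 2**(3/8) ~ 1.297.
    for lo, hi in zip(mags, mags[1:]):
        assert abs(hi / lo - 2 ** 0.375) < 1e-12

    print([round(v, 7) for v in mags])
    # [0.1621049, 0.2102241, 0.2726269, 0.3535534, 0.458502, 0.5946036, 0.7711054, 1.0]
    # i.e. the positive half of pbf4_dequantization_lut in csrc/kernels.cu.

    # Geometric-midpoint split points for the encoder (the zero/mag-0 split is linear);
    # these land within about 1e-4 of the constants hard-coded in dQuantizePBF4.
    thresholds = [mags[0] / 2.0] + [math.sqrt(lo * hi) for lo, hi in zip(mags, mags[1:])]
    print([round(t, 5) for t in thresholds])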