145 changes: 145 additions & 0 deletions bitsandbytes/_pbf4.py
@@ -0,0 +1,145 @@
"""PBF4 — peace-quant's PBF-MX, derived dynamically from the PBF8 spine.

The 4-bit LUT is generated by sampling 8 magnitudes from the standard PBF8
spine (``_pbf8``) at every-other level — same construction as peace-quant's
[mx_pbf_lut](crates/pbf8/src/mx4.rs#L25). Step ratio between levels is
``exp(2·LEVEL_LOG_STEP) = 2^(3/8) ≈ 1.297`` (~30%/level). Magnitudes are
normalised so the top entry is exactly 1.0, then mirrored into NF4-style
asymmetric layout (7 neg + 0 + 8 pos = 16 unique values).

Same LUT for every tensor — no per-tensor calibration. The LUT lives in
``QuantState.code`` so CUDA's LUT-generic ``kgemm_4bit_inference_naive`` kernel
runs unchanged.
"""

from typing import Optional

import torch

from . import _pbf8

# 4-bit NF4-layout: 7 neg + 0 + 8 pos = 16 unique LUT entries.
NUM_LUT_MAGS: int = 8

# Sample 8 magnitudes from the PBF8 spine at every-other level (stride 2),
# starting at level 2 so the highest entry comes from PBF8 mag-index 16
# (= ``RING_POW[1]``). The first sampled magnitude sits at level 2
# (= ``RING_POW[0] * exp(2 * LEVEL_LOG_STEP)`` ≈ 7.5e-4 in raw PBF8 space);
# after normalising the top to 1.0 we get peace-quant's pbf_mx_lut shape:
# ``[0.162, 0.21, 0.273, 0.354, 0.458, 0.594, 0.771, 1.0]``.
_PBF4_PBF8_START_LEVEL: int = 2


def _build_pbf_mx_lut() -> torch.Tensor:
"""Generate the PBF4 LUT by sub-sampling the PBF8 spine."""
raw_mags = _pbf8.sample_every_other_level(NUM_LUT_MAGS, start_level=_PBF4_PBF8_START_LEVEL)
top = raw_mags[-1]
mags = [m / top for m in raw_mags]
values = [-m for m in reversed(mags[1:])] + [0.0] + mags
return torch.tensor(values, dtype=torch.float32)


# PBF4 LUT — derived from PBF8 at module load. Same LUT for every tensor.
PBF_MX_LUT: torch.Tensor = _build_pbf_mx_lut()


# ---------------------------------------------------------------------------
# Block normalisation (per-block absmax → [-1, 1]).
# ---------------------------------------------------------------------------


def _block_normalise(A: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor, int]:
n = A.numel()
full_blocks = n // blocksize
rem = n % blocksize
blocks = full_blocks + (1 if rem else 0)
A_flat = A.reshape(n).float()
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)

parts = []
if full_blocks:
full = A_flat[: full_blocks * blocksize].reshape(full_blocks, blocksize)
absmax[:full_blocks] = full.abs().max(dim=-1).values
denom = absmax[:full_blocks].clamp(min=1e-30).view(-1, 1)
parts.append((full / denom).clamp(-1, 1).reshape(-1))
if rem:
rem_part = A_flat[-rem:]
absmax[-1] = rem_part.abs().max()
denom = absmax[-1].clamp(min=1e-30)
parts.append((rem_part / denom).clamp(-1, 1))
scaled = torch.cat(parts, dim=0) if parts else A_flat
return scaled, absmax, full_blocks


# ---------------------------------------------------------------------------
# Quantize / dequantize using the fixed PBF-MX LUT.
# ---------------------------------------------------------------------------


def quantize_pbf4_blockwise(
A: torch.Tensor,
blocksize: int,
quant_storage: torch.dtype,
lut: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Quantize ``A`` to PBF4. Returns ``(packed_codes, absmax, lut)``.

LUT defaults to ``PBF_MX_LUT`` (top = 1.0), so the bnb kernel formula
``quant_map[byte] * absmax`` reproduces block-max directly.
"""
if lut is None:
lut = PBF_MX_LUT
lut = lut.to(A.device)

scaled, absmax, _ = _block_normalise(A, blocksize)
diff = (scaled.unsqueeze(-1) - lut.to(scaled.dtype)).abs()
codes = diff.argmin(dim=-1).to(torch.int64)

if codes.numel() % 2 != 0:
codes = torch.cat([codes, codes.new_zeros(1)])
packed = ((codes[::2] << 4) | codes[1::2]).to(torch.uint8).unsqueeze(1)

if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax, lut


def dequantize_pbf4_blockwise(
A: torch.Tensor,
absmax: torch.Tensor,
lut: torch.Tensor,
blocksize: int,
shape,
dtype: torch.dtype,
) -> torch.Tensor:
"""Inverse of ``quantize_pbf4_blockwise``: ``lut[byte] * absmax``."""
if A.dtype != torch.uint8:
A = A.view(torch.uint8)
A = A.reshape(-1)

out_codes = torch.empty(A.numel() * 2, dtype=torch.int64, device=A.device)
out_codes[::2] = (A >> 4).long()
out_codes[1::2] = (A & 0xF).long()

n = 1
for s in shape:
n *= s
out_codes = out_codes[:n]

lut_dt = lut.to(dtype).to(A.device)
decoded = lut_dt[out_codes]

blocks = (n + blocksize - 1) // blocksize
rem = n % blocksize
has_rem = rem > 0
out = torch.empty((n,), dtype=dtype, device=A.device)
if has_rem:
full_n = (blocks - 1) * blocksize
if full_n:
out[:full_n] = (decoded[:full_n].view(-1, blocksize) * absmax[: blocks - 1].to(dtype).view(-1, 1)).reshape(
-1
)
out[full_n:] = decoded[full_n:] * absmax[-1].to(dtype)
else:
out = (decoded.view(blocks, blocksize) * absmax.to(dtype).view(-1, 1)).reshape(-1)

return out.reshape(shape)
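A minimal round-trip sketch of the two entry points above (module path and function names are taken from this file; the blocksize, tensor shape, and error bound are illustrative assumptions rather than figures from this PR):

import torch

from bitsandbytes._pbf4 import dequantize_pbf4_blockwise, quantize_pbf4_blockwise

# Quantize a random fp32 weight with 64-element blocks and uint8 storage.
W = torch.randn(256, 256, dtype=torch.float32)
packed, absmax, lut = quantize_pbf4_blockwise(W, blocksize=64, quant_storage=torch.uint8)

# packed holds two 4-bit codes per byte; absmax holds one fp32 scale per block.
assert packed.shape == (W.numel() // 2, 1)
assert absmax.shape == (W.numel() // 64,)

W_hat = dequantize_pbf4_blockwise(packed, absmax, lut, blocksize=64, shape=W.shape, dtype=torch.float32)

# 16 LUT levels per block make this lossy; the bound is only a loose sanity check.
assert (W - W_hat).norm() / W.norm() < 0.2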
59 changes: 59 additions & 0 deletions bitsandbytes/_pbf8.py
@@ -0,0 +1,59 @@
"""PBF8 — peace-quant's standard 8-bit log-polar spine.

Mirrors the constants from
[crates/pbf8/src/format.rs](crates/pbf8/src/format.rs):

- ``BASE = φ + π`` (≈ 4.7596) — the irrational base anchoring all rings.
- ``RING_POW`` — 8-element ring spine, each ring 8x the previous (R=8 spacing).
Spans roughly ``5.8e-4 .. 1.2e3``.
- ``LEVEL_LOG_STEP = ln(8) / 16`` — log-step per level inside a ring.
- 8 rings * 16 levels = 128 magnitude indices per sign. Byte 0 (zero) and
byte 255 (saturation) are sentinels, so nonzero magnitudes occupy the
remaining 254 byte codes.

For a magnitude index ``mag ∈ [0, 127]``:
``decode_mag(mag) = (BASE/8192) · exp(mag · LEVEL_LOG_STEP) = RING_POW[mag>>4] · exp((mag & 0xF) · LEVEL_LOG_STEP)``.

This module exposes the spine to higher-level formats. PBF4 (``_pbf4``) builds
its 4-bit LUT by sampling 8 magnitudes from this spine at every-other level.
"""

import math

PHI: float = 1.618_034
BASE: float = PHI + math.pi # ≈ 4.7596267

# Ring spine — 8 rings at R=8 spacing, span ≈ 5.8e-4 .. 1.2e3.
RING_POW: tuple[float, ...] = (
BASE / 8192.0,
BASE / 1024.0,
BASE / 128.0,
BASE / 16.0,
BASE / 2.0,
4.0 * BASE,
32.0 * BASE,
256.0 * BASE,
)

LEVEL_LOG_STEP: float = math.log(8.0) / 16.0 # = 3·ln(2)/16

# 128 magnitudes (mag indices 0..127). Byte 0 is the zero sentinel, byte 255
# is the saturation sentinel; nonzero magnitudes occupy bytes 1..254.
N_MAGS: int = 128


def decode_mag(mag: int) -> float:
"""Decode a magnitude index ``mag ∈ [0, N_MAGS)`` to its positive fp32 value."""
if mag <= 0:
return RING_POW[0]
if mag >= N_MAGS:
mag = N_MAGS - 1
return (BASE / 8192.0) * math.exp(mag * LEVEL_LOG_STEP)


def sample_every_other_level(n: int = 8, start_level: int = 0) -> list[float]:
"""Return ``n`` magnitudes by sampling the PBF8 spine at every-other level.

Used to derive lower-bit-depth LUTs (PBF4 takes ``n=8``). Sampling stride
is 2 levels = ``2 · LEVEL_LOG_STEP`` (~30% step ratio).
"""
return [decode_mag(start_level + 2 * k) for k in range(n)]
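A quick numeric check of the sampling scheme described above, using only the constants and helpers in this file; the expected magnitudes are the ones quoted in the ``_pbf4`` comment (three decimals there, so a 1e-3 tolerance here):

import math

from bitsandbytes._pbf8 import LEVEL_LOG_STEP, decode_mag, sample_every_other_level

# Step ratio between consecutive samples: exp(2 * ln(8)/16) = 2**(3/8) ≈ 1.297.
assert math.isclose(math.exp(2 * LEVEL_LOG_STEP), 2 ** (3 / 8))

# PBF4 samples PBF8 levels 2, 4, ..., 16; the top sample is decode_mag(16), i.e. RING_POW[1].
raw = sample_every_other_level(8, start_level=2)
assert math.isclose(raw[-1], decode_mag(16))

# Normalising the top entry to 1.0 reproduces the pbf_mx_lut magnitudes.
mags = [m / raw[-1] for m in raw]
expected = [0.162, 0.21, 0.273, 0.354, 0.458, 0.594, 0.771, 1.0]
assert all(abs(m - e) < 1e-3 for m, e in zip(mags, expected))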
10 changes: 9 additions & 1 deletion bitsandbytes/autograd/_functions.py
@@ -382,8 +382,16 @@ def matmul_4bit(
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
# PBF4 stores its fixed 16-entry LUT in quant_state.code; CUDA's
# kgemm_4bit_inference_naive loads ``datatype`` into shared memory and is fully
# LUT-generic, so pbf4 piggybacks the existing gemv_4bit fast path with no
# kernel changes. The CPU fused-gemm path (``packing_format_for_cpu``) is
# opt-in and uses a code-tensor heuristic that only distinguishes nf4 from
# fp4, so we keep pbf4 off that specific lane.
is_pbf4 = getattr(quant_state, "quant_type", None) == "pbf4"

if A.device.type == "cpu":
if getattr(quant_state, "packing_format_for_cpu", False):
if getattr(quant_state, "packing_format_for_cpu", False) and not is_pbf4:
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
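For orientation, a sketch of how a pbf4 ``QuantState`` could be assembled so that the routing above sees the fixed LUT through ``quant_state.code``. The ``QuantState`` keyword names follow ``bitsandbytes.functional``, but the actual integration added by this PR is not shown in this diff, so treat the wiring as an assumption rather than the PR's public API:

import torch

import bitsandbytes as bnb
from bitsandbytes._pbf4 import quantize_pbf4_blockwise
from bitsandbytes.functional import QuantState

W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed, absmax, lut = quantize_pbf4_blockwise(W, blocksize=64, quant_storage=torch.uint8)

# The LUT rides along in .code, so the LUT-generic CUDA gemv_4bit kernel needs no changes.
state = QuantState(
    absmax=absmax,
    shape=W.shape,
    code=lut,
    blocksize=64,
    quant_type="pbf4",
    dtype=torch.float16,
)

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, packed.t(), quant_state=state)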
10 changes: 8 additions & 2 deletions bitsandbytes/backends/cpu/ops.py
@@ -147,7 +147,10 @@ def _(
dtype: torch.dtype,
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
quant_type in ("nf4", "fp4", "pbf4"),
lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}",
)
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
@@ -160,7 +163,10 @@ def _(
# Odd shape is not supported by this kernel; fallback to generic implementation
shape_fallback = shape[-1] % 2 != 0

if avx512_fallback or shape_fallback:
# No native pbf4 CPU kernel — route to the LUT-driven default impl.
pbf4_fallback = quant_type == "pbf4"

if avx512_fallback or shape_fallback or pbf4_fallback:
from ..default.ops import _dequantize_4bit_impl

return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
16 changes: 14 additions & 2 deletions bitsandbytes/backends/cuda/ops.py
@@ -313,7 +313,7 @@ def _(
A = A.contiguous()
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(quant_type in ["fp4", "nf4", "pbf4"])
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
@@ -337,16 +337,22 @@
if A.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(*args)
elif quant_type == "pbf4":
lib.cquantize_blockwise_bf16_pbf4(*args)
else:
lib.cquantize_blockwise_bf16_nf4(*args)
elif A.dtype == torch.float16:
if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4(*args)
elif quant_type == "pbf4":
lib.cquantize_blockwise_fp16_pbf4(*args)
else:
lib.cquantize_blockwise_fp16_nf4(*args)
elif A.dtype == torch.float32:
if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(*args)
elif quant_type == "pbf4":
lib.cquantize_blockwise_fp32_pbf4(*args)
else:
lib.cquantize_blockwise_fp32_nf4(*args)

@@ -393,7 +399,7 @@ def _dequantize_4bit_impl(
A = A.contiguous()
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(quant_type in ["fp4", "nf4", "pbf4"])
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
@@ -413,16 +419,22 @@ def _(
if out.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
elif quant_type == "pbf4":
lib.cdequantize_blockwise_bf16_pbf4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif out.dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
elif quant_type == "pbf4":
lib.cdequantize_blockwise_fp16_pbf4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif out.dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
elif quant_type == "pbf4":
lib.cdequantize_blockwise_fp32_pbf4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)

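All of the branches above follow the existing ``c{quantize,dequantize}_blockwise_{dtype}_{quant_type}`` symbol naming, with pbf4 slotted in next to fp4/nf4. A table-driven equivalent of the same dispatch (purely illustrative; it assumes the pbf4 entry points exported by this PR's native-library changes use exactly those names):

import torch

_DTYPE_TAG = {torch.bfloat16: "bf16", torch.float16: "fp16", torch.float32: "fp32"}

def _blockwise_fn(lib, kind: str, dtype: torch.dtype, quant_type: str):
    # kind is "cquantize" or "cdequantize"; e.g. ("cquantize", torch.float16, "pbf4")
    # resolves to lib.cquantize_blockwise_fp16_pbf4.
    return getattr(lib, f"{kind}_blockwise_{_DTYPE_TAG[dtype]}_{quant_type}")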
41 changes: 35 additions & 6 deletions bitsandbytes/backends/default/ops.py
@@ -216,17 +216,28 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
return out


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
def _quantize_4bit_impl(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
quant_type in ("nf4", "fp4", "pbf4"),
lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}",
)
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

if quant_type == "pbf4":
# PBF4 uses a fixed LUT derived from the PBF8 spine (the same LUT on every
# call), so the op-level path can produce correct (packed, absmax)
# without any extra metadata. Python fallback for non-CUDA backends.
from ..._pbf4 import PBF_MX_LUT, quantize_pbf4_blockwise

packed, absmax, _ = quantize_pbf4_blockwise(A, blocksize, quant_storage, lut=PBF_MX_LUT)
return packed, absmax

n = A.numel()
full_blocks = n // blocksize
rem = n % blocksize
@@ -262,6 +273,13 @@ def _(
return packed, absmax.float()


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
@@ -270,6 +288,11 @@ def _dequantize_4bit_impl(
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
if quant_type == "pbf4":
from ..._pbf4 import PBF_MX_LUT, dequantize_pbf4_blockwise

return dequantize_pbf4_blockwise(A, absmax, PBF_MX_LUT, blocksize, shape, dtype)

# Enable non uint8 dtype
if A.dtype != torch.uint8:
A = A.view(torch.uint8)
@@ -318,7 +341,10 @@ def _(
dtype: torch.dtype,
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
quant_type in ("nf4", "fp4", "pbf4"),
lambda: f"quant_type must be nf4, fp4, or pbf4, got {quant_type}",
)
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
@@ -336,8 +362,11 @@ def _(
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# Applied from dequantize_4bit
quant_type = "fp4" if code[1] > 0 else "nf4"
# Recover quant_type from the LUT's distinctive ``code[1]``:
# fp4 ≈ +0.0052 (small positive); nf4 ≈ -0.696. (PBF4 keeps its fixed LUT in
# ``quant_state.code`` but never reaches this op-level path; see
# ``functional.dequantize_4bit``, which short-circuits PBF4 directly.)
quant_type = "fp4" if float(code[1]) > 0 else "nf4"
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)

return torch.nn.functional.linear(
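A worked check of the ``code[1]`` sign heuristic above (the nf4/fp4 numbers are the commonly quoted bitsandbytes code-table entries, rounded; the pbf4 value follows from ``PBF_MX_LUT`` in this PR):

import torch

from bitsandbytes._pbf4 import PBF_MX_LUT

def guess_quant_type(code: torch.Tensor) -> str:
    return "fp4" if float(code[1]) > 0 else "nf4"

nf4_code1 = -0.6962                # second nf4 entry: negative -> "nf4"
fp4_code1 = 0.0052                 # second fp4 entry (0.0625 / 12): positive -> "fp4"
pbf4_code1 = float(PBF_MX_LUT[1])  # ≈ -0.771: negative, so pbf4 would be misread
                                   # as nf4 here, which is why the pbf4 path never
                                   # routes through this op.

assert guess_quant_type(torch.tensor([-1.0, nf4_code1])) == "nf4"
assert guess_quant_type(torch.tensor([0.0, fp4_code1])) == "fp4"
assert guess_quant_type(PBF_MX_LUT) == "nf4"  # would-be misclassification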