23 changes: 23 additions & 0 deletions CMakeLists.txt
@@ -234,6 +234,29 @@ if(BUILD_CUDA)

list(APPEND SRC_FILES ${GPU_FILES})

# 4-bit GEMM SIMT and dispatch compile for all selected archs.
# sm75/sm80+ MMA files only compile if those archs are selected.
set(_cc_all ${COMPUTE_CAPABILITY} ${_LATEST_CAPABILITY})
list(APPEND SRC_FILES csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu)

if(75 IN_LIST _cc_all)
# Builds only on sm75
list(APPEND SRC_FILES csrc/gemm_4bit_sm75.cu)
add_compile_definitions(BNB_HAS_GEMM4BIT_SM75)
endif()

set(_cc_sm80plus)
foreach(_cc IN LISTS _cc_all)
if(_cc GREATER_EQUAL 80)
list(APPEND _cc_sm80plus ${_cc})
endif()
endforeach()
if(_cc_sm80plus)
# Builds only on sm80+
list(APPEND SRC_FILES csrc/gemm_4bit_sm80.cu)
add_compile_definitions(BNB_HAS_GEMM4BIT_SM80)
endif()

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_HIP)
59 changes: 59 additions & 0 deletions bitsandbytes/_ops.py
@@ -220,6 +220,65 @@ def _(
return out, absmax


torch.library.define(
"bitsandbytes::gemm_4bit",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, int blocksize, str quant_type, "
"Tensor? bias=None, Tensor? absmax_8bit=None, Tensor? absmax_code=None, Tensor? absmax_offset=None) -> Tensor",
)


@register_fake("bitsandbytes::gemm_4bit")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(len(shapeB) == 2, lambda: f"shapeB must be 2D [N, K], got {list(shapeB)}")
torch._check(A.shape[-1] == shapeB[1], lambda: f"A inner dim ({A.shape[-1]}) must match shapeB ({shapeB[1]})")
torch._check(
A.dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32),
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(blocksize in [32, 64, 128, 256, 512, 1024, 2048, 4096], lambda: f"invalid blocksize {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}")
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
if absmax_8bit is not None:
torch._check(absmax_8bit.ndim == 1, lambda: f"absmax_8bit must be 1D, got {absmax_8bit.ndim}D")
torch._check(absmax_8bit.dtype == torch.uint8, lambda: f"absmax_8bit must be uint8, got {absmax_8bit.dtype}")
torch._check(absmax_code is not None, lambda: "absmax_code required when absmax_8bit is provided")
torch._check(absmax_code.ndim == 1, lambda: f"absmax_code must be 1D, got {absmax_code.ndim}D")
torch._check(
absmax_code.shape[0] == 256, lambda: f"absmax_code must have 256 entries, got {absmax_code.shape[0]}"
)
torch._check(
absmax_code.dtype == torch.float32, lambda: f"absmax_code must be float32, got {absmax_code.dtype}"
)
torch._check(absmax_offset is not None, lambda: "absmax_offset required when absmax_8bit is provided")
torch._check(
absmax_offset.ndim == 0, lambda: f"absmax_offset must be a scalar (0-dim), got {absmax_offset.ndim}D"
)
torch._check(
absmax_offset.dtype == torch.float32, lambda: f"absmax_offset must be float32, got {absmax_offset.dtype}"
)
if bias is not None:
torch._check(bias.ndim == 1, lambda: f"bias must be 1D, got {bias.ndim}D")
torch._check(bias.shape[0] == shapeB[0], lambda: f"bias length ({bias.shape[0]}) must match N ({shapeB[0]})")
torch._check(bias.dtype == A.dtype, lambda: f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})")
N = shapeB[0]
return torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
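A minimal trace-time sketch of the schema above (not part of the diff): with the PR applied and the ops registered (e.g. via `import bitsandbytes`), the fake kernel should also serve meta tensors, so the output shape can be checked without a GPU. Shapes below are arbitrary.

```python
import torch
import bitsandbytes  # noqa: F401  -- assumption: this PR is applied, so bitsandbytes::gemm_4bit is registered

M, N, K, blocksize = 4, 512, 256, 64
A = torch.empty(M, K, dtype=torch.float16, device="meta")
# Packed 4-bit weight: two nibbles per byte, canonical [(N*K+1)//2, 1] layout.
B = torch.empty((N * K + 1) // 2, 1, dtype=torch.uint8, device="meta")
absmax = torch.empty(N * K // blocksize, dtype=torch.float32, device="meta")

out = torch.ops.bitsandbytes.gemm_4bit.default(A, B, [N, K], absmax, blocksize, "nf4")
print(out.shape)  # torch.Size([4, 512]) == [*A.shape[:-1], N]
```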


torch.library.define(
"bitsandbytes::dequantize_blockwise",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
131 changes: 108 additions & 23 deletions bitsandbytes/autograd/_functions.py
@@ -315,9 +315,37 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
# Normalize to canonical [(N*K+1)//2, 1]. Packed weights are always contiguous
# in this orientation (B.t() callers get strides [1,1], still compatible).
# quant_state.shape is the source of truth for N and K.
B = B.view(-1, 1)

if not quant_state.nested:
output = torch.ops.bitsandbytes.gemm_4bit.default(
A,
B,
quant_state.shape,
quant_state.absmax,
quant_state.blocksize,
quant_state.quant_type,
bias=bias,
)
elif quant_state.state2.blocksize == 256:
output = torch.ops.bitsandbytes.gemm_4bit.default(
A,
B,
quant_state.shape,
quant_state.state2.absmax,
quant_state.blocksize,
quant_state.quant_type,
bias=bias,
absmax_8bit=quant_state.absmax,
absmax_code=quant_state.state2.code,
absmax_offset=quant_state.offset,
)
else:
raise NotImplementedError("nested quantization with state2.blocksize != 256 is not supported")

if out is not None:
out.copy_(output)
output = out
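For context on the nested branch above (not part of the diff): the extra arguments follow bitsandbytes' existing double-quantization layout, where the per-block absmax values are themselves quantized to uint8 against a 256-entry code, with the subtracted mean stored as an offset. A host-side reference decode might look like the sketch below, assuming `state2.blocksize == 256`; the CUDA kernel is presumably doing the equivalent on the fly.

```python
import torch

def decode_nested_absmax(absmax_8bit, absmax_code, state2_absmax, absmax_offset):
    # Blockwise 8-bit dequantization: look up each byte in the 256-entry code,
    # scale by the float32 absmax of its 256-element block, then add the offset back.
    vals = absmax_code[absmax_8bit.long()]                          # [num_blocks]
    scales = state2_absmax.repeat_interleave(256)[: vals.numel()]   # [num_blocks]
    return vals * scales + absmax_offset                            # per-block fp32 absmax
```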
@@ -351,7 +379,9 @@ def backward(ctx, grad_output):
# not supported by PyTorch. TODO: create work-around
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
# B in ctx.tensors is already in canonical [(N*K+1)//2, 1] form (normalized in forward).
# dequantize returns [N, K]; matmul(grad_output[M,N], [N,K]) = grad_A[M,K].
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))

return grad_A, grad_B, None, grad_bias, None

@@ -381,26 +411,81 @@ def matmul_4bit(
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
if A.device.type == "cpu":
if getattr(quant_state, "packing_format_for_cpu", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
if quant_state is None:
raise ValueError("quant_state is required")
if len(quant_state.shape) != 2:
raise ValueError("matmul_4bit: quant_state.shape must be 2D [N, K]")

# packing_format_for_cpu uses a different memory layout optimized for AVX512BF16.
# This flag is only set for inference (weight conversion happens at eval time).
# The underlying kernel supports any M via tiled GEMM despite the gemv name.
if A.device.type == "cpu" and getattr(quant_state, "packing_format_for_cpu", False):
result = F.gemv_4bit(A, B, out=out, state=quant_state)
if bias is not None:
result += bias
return result

# Normalize B to canonical [(N*K+1)//2, 1]. Packed weights are always contiguous
# in this orientation (B.t() callers get strides [1,1], still compatible).
# quant_state.shape is the source of truth for N and K.
B = B.view(-1, 1)

K = A.shape[-1]

# Weight is in [K, N] orientation when A's inner dim matches shape[0] not shape[1].
# Square weights (K==N) are ambiguous and treated as [N, K].
if K == quant_state.shape[0] and K != quant_state.shape[1]:
if not _is_compiling():
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
f"matmul_4bit: weight was quantized from a [K, N] tensor (quant_state.shape={list(quant_state.shape)}). "
"Re-quantize from the weight in [N, K] (out_features, in_features) orientation. "
"This will be an error in a future version.",
DeprecationWarning,
stacklevel=2,
)
B_dq = F.dequantize_4bit(B, quant_state).to(A.dtype)
result = torch.nn.functional.linear(A, B_dq.t(), bias)
if out is not None:
out.copy_(result)
return out
return result

needs_grad = torch.is_grad_enabled() and (A.requires_grad or (bias is not None and bias.requires_grad))
if not needs_grad:
A_numel = A.numel()
if A_numel == 0:
if out is not None:
return out
return torch.empty((*A.shape[:-1], quant_state.shape[0]), dtype=A.dtype, device=A.device)

if not quant_state.nested:
result = torch.ops.bitsandbytes.gemm_4bit.default(
A,
B,
quant_state.shape,
quant_state.absmax,
quant_state.blocksize,
quant_state.quant_type,
bias=bias,
)
elif quant_state.state2.blocksize == 256:
result = torch.ops.bitsandbytes.gemm_4bit.default(
A,
B,
quant_state.shape,
quant_state.state2.absmax,
quant_state.blocksize,
quant_state.quant_type,
bias=bias,
absmax_8bit=quant_state.absmax,
absmax_code=quant_state.state2.code,
absmax_offset=quant_state.offset,
)
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
raise NotImplementedError("nested quantization with state2.blocksize != 256 is not supported")
if out is not None:
out.copy_(result)
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
return result

return MatMul4Bit.apply(A, B, out, bias, quant_state)
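A possible end-to-end use of the new path (a sketch, not part of the diff; assumes a CUDA build that compiled the new gemm_4bit kernels, arbitrary shapes, and `compress_statistics=True` to exercise the nested branch):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

N, K = 512, 256  # out_features, in_features
W = torch.randn(N, K, dtype=torch.float16, device="cuda")
A = torch.randn(4, K, dtype=torch.float16, device="cuda")

# Quantize from the [N, K] (out_features, in_features) orientation, matching the
# deprecation warning above; the nested absmax uses its default 256 blocksize.
packed, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4", compress_statistics=True)

out = bnb.matmul_4bit(A, packed, quant_state=state)            # [4, 512]
ref = A @ F.dequantize_4bit(packed, state).to(A.dtype).t()     # dense reference
print((out - ref).abs().max())  # small quantization/accumulation error expected
```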