Summary
bitsandbytes.nn.Linear4bit overrides _save_to_state_dict (bitsandbytes/nn/modules.py:593) to write QuantState components alongside the packed weight:
```python
def _save_to_state_dict(self, destination, prefix, keep_vars):
    super()._save_to_state_dict(destination, prefix, keep_vars)  # weight + bias

    if getattr(self.weight, "quant_state", None) is not None:
        for k, v in self.weight.quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + k] = v if keep_vars else v.detach()
```
Resulting state_dict keys for a Linear4bit (with compress_statistics=True):
- weight (packed 4-bit tensor)
- weight.absmax
- weight.quant_map
- weight.nested_absmax
- weight.nested_quant_map
- weight.quant_state.bitsandbytes__nf4 (or __fp4)
Linear4bit does not define a corresponding _load_from_state_dict. It inherits nn.Linear._load_from_state_dict, which only consumes weight and bias. The QuantState keys land in unexpected_keys during load. With strict=True (the model.load_state_dict() default) this raises RuntimeError.
This is asymmetric relative to Linear8bitLt, which defines both _save_to_state_dict (line 1095) and _load_from_state_dict (line 1119), with the latter explicitly walking unexpected_keys to consume the SCB tensor and remove it from the unexpected list.
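For reference, the 8-bit pattern is roughly the following (a paraphrase of Linear8bitLt._load_from_state_dict from memory, not the verbatim source):

```python
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)
    # Claim the SCB statistics that nn.Linear could not place
    for key in list(unexpected_keys):
        if key[len(prefix):] == "SCB":
            if self.weight.SCB is None:
                raise RuntimeError(
                    "Loading a quantized checkpoint into a non-quantized "
                    "Linear8bitLt; call module.cuda() before load_state_dict()."
                )
            self.weight.SCB.copy_(state_dict[key])
            unexpected_keys.remove(key)
```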
Reproducer
```python
import torch
import bitsandbytes as bnb

src = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                        compute_dtype=torch.bfloat16,
                        compress_statistics=True)
src = src.to('cuda')  # triggers quantize, populates quant_state

sd = src.state_dict()
print("state_dict keys:", list(sd.keys()))
# ['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax',
#  'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__nf4']

dst = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                        compute_dtype=torch.bfloat16,
                        compress_statistics=True)
dst = dst.to('cuda')

# Fails with strict=True (the default):
dst.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for Linear4bit:
#   Unexpected key(s) in state_dict: "weight.absmax", "weight.quant_map",
#   "weight.nested_absmax", "weight.nested_quant_map",
#   "weight.quant_state.bitsandbytes__nf4"
```
With strict=False, the QuantState components are silently ignored: dst.weight.data becomes the packed bytes from src, but dst.weight.quant_state is the one populated during dst.to('cuda') — not the one from sd. If src and dst were quantized from different weights (very common, e.g. loading a checkpoint into a freshly-initialized module), dst is left in a silently-corrupt state: packed bytes from one model, QuantState scalars from another.
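The corruption is easy to demonstrate by dequantizing both sides after the strict=False load (continuing the reproducer above; dequantize_4bit is bitsandbytes.functional's public dequantization entry point):

```python
import bitsandbytes.functional as F

dst.load_state_dict(sd, strict=False)  # copies packed bytes, drops QuantState
w_src = F.dequantize_4bit(src.weight.data, src.weight.quant_state)
w_dst = F.dequantize_4bit(dst.weight.data, dst.weight.quant_state)
print(torch.allclose(w_src, w_dst))  # False: same bytes, mismatched scales
```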
Why this hasn't been reported as a user-facing bug
Standard workflows bypass model.load_state_dict() for bnb-quantized checkpoints:
- HF Transformers from_pretrained for pre-quantized bnb checkpoints: uses bnb.nn.Params4bit.from_prequantized(data, quantized_stats, device=...) directly via transformers/quantizers/quantizer_bnb_4bit.py::create_quantized_param. The quantized_stats dict is built by walking state_dict keys with prefix param_name + ".". Bypasses _load_from_state_dict entirely (see the sketch after this list).
- PEFT save/load: get_peft_model_state_dict filters to LoRA-only keys ("lora_" in k); QuantState keys never enter the PEFT round-trip.
- Accelerate's fsdp2_load_full_state_dict: has its own broadcast + assign=True path; PR #3982 added key-based matching that filters non-Params keys.
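For concreteness, the transformers path is roughly the following (paraphrased, not verbatim; param_name, param_value, tensor_name, and target_device stand in for the quantizer's local variables):

```python
# Inside create_quantized_param (paraphrase): gather every state_dict entry
# under `param_name + "."` as quantized stats, then rebuild the parameter
# directly, never touching _load_from_state_dict.
quantized_stats = {k: v for k, v in state_dict.items() if param_name + "." in k}
module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
    data=param_value,
    quantized_stats=quantized_stats,
    requires_grad=False,
    device=target_device,
)
```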
Paths that do hit the asymmetry:
- torch.distributed.checkpoint (DCP) when the loaded state_dict contains QuantState keys
- Custom training loops that save via model.state_dict() and resume via model.load_state_dict() (the failure shape is sketched below)
- The docstring example at nn/modules.py:522 (quantized_model.load_state_dict(fp16_model.state_dict())) works only because the source is fp16 (no QuantState keys to be unexpected) — masking the bnb→bnb round-trip case
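The custom-loop failure shape, for completeness (assumes a model whose Linear4bit layers were quantized via .to('cuda')):

```python
torch.save(model.state_dict(), "ckpt.pt")
# ... later, in a fresh process with a freshly built (and quantized) model ...
model.load_state_dict(torch.load("ckpt.pt"))  # RuntimeError: Unexpected key(s)
```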
Discovery context
Surfaced while characterizing FSDP2 + Params4bit forward correctness on Jetson Orin Nano Super (sm_87) — the same investigation that led to #1945. The investigation needed to round-trip QuantState through state_dict() for a pre-shard-broadcast pattern; that worked on the save side and failed on the load side, leading to this finding. The two issues are unrelated mechanically — this one is a missing override, #1945 is a PyTorch FSDP2 NaN-canonicalization bug — but they share discovery provenance.
Suggested fix
Implement _load_from_state_dict paralleling Linear8bitLt's pattern. Sketch:
```python
def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict,
    missing_keys, unexpected_keys, error_msgs,
):
    # Collect QuantState components in state_dict for this prefix
    qs_keys_to_consume = []
    quantized_stats = {}
    weight_dot_prefix = prefix + "weight."
    for k in list(state_dict.keys()):
        if k.startswith(weight_dot_prefix):
            qs_keys_to_consume.append(k)
            # store as `weight.<subkey>` for from_prequantized
            quantized_stats[k[len(prefix):]] = state_dict[k]

    # Standard nn.Linear path consumes 'weight' and 'bias'
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    )

    # If QuantState components are present, reconstruct via from_prequantized
    if quantized_stats:
        for k in qs_keys_to_consume:
            if k in unexpected_keys:
                unexpected_keys.remove(k)
        weight_data = self.weight.data  # already loaded by super()
        self.weight = Params4bit.from_prequantized(
            data=weight_data,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=weight_data.device,
            module=self,
        )
```
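With the override in place, the reproducer's failing line should pass under strict=True, and since both the packed bytes and the QuantState now come from sd, outputs should match exactly:

```python
dst.load_state_dict(sd)  # no unexpected keys once the override exists
x = torch.randn(4, 64, device='cuda', dtype=torch.bfloat16)
torch.testing.assert_close(src(x), dst(x), rtol=0, atol=0)
```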
Edge cases:
- If QuantState keys are absent (e.g., the existing fp16→bnb docstring example), quantized_stats is empty and the function falls through to standard nn.Linear behavior — no regression.
- The weight_dot_prefix match cleanly delimits per-Linear keys; the state_dict passed to _load_from_state_dict is already filtered to the current module's prefix by nn.Module.load_state_dict()'s recursion (a quick sanity check follows).
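A quick sanity check of that prefix claim, nesting a Linear4bit one level deep:

```python
model = torch.nn.Sequential(bnb.nn.Linear4bit(64, 64, bias=False)).to('cuda')
print(list(model.state_dict().keys()))
# ['0.weight', '0.weight.absmax', '0.weight.quant_map', ...]
# During load, the child sees prefix='0.', so prefix + "weight." matches
# exactly the QuantState keys belonging to this Linear4bit.
```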
Test plan (for the fix)
- m1.state_dict() → m2.load_state_dict(strict=True) on CPU then CUDA Linear4bit; assert bnb.matmul_4bit produces identical output on synthetic input.
- compress_statistics on/off, quant_storage ∈ {uint8, bf16, fp16}, with/without bias.
- strict=False behavior with partial QuantState keys (fall back to existing _quantize flow on .to(device) if applicable).
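A hedged pytest sketch of the first item (test name and parametrization are mine):

```python
import pytest
import torch
import bitsandbytes as bnb

@pytest.mark.parametrize("compress_statistics", [True, False])
def test_linear4bit_state_dict_roundtrip(compress_statistics):
    kwargs = dict(bias=False, quant_type="nf4",
                  compute_dtype=torch.bfloat16,
                  compress_statistics=compress_statistics)
    m1 = bnb.nn.Linear4bit(64, 64, **kwargs).to("cuda")
    m2 = bnb.nn.Linear4bit(64, 64, **kwargs).to("cuda")
    m2.load_state_dict(m1.state_dict())  # strict=True by default
    x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(m1(x), m2(x))
```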
Severity
Latent — currently masked. Does not surface in standard HF + PEFT workflows because they bypass load_state_dict. But it's a contract violation between _save_to_state_dict and the standard nn.Module load mechanism, and surfaces for:
- torch.distributed.checkpoint (DCP) usage
- Custom training-loop checkpoint code
- Any tooling that round-trips through model.state_dict() + model.load_state_dict()
Adjacent PRs (for context, do not address this)
- Params4bit.__getattr__ proxy for FSDP _get_fqns() traversal (state_dict traversal, not load).
- __getattr__ with @property descriptors (motivated by torch.compile graph breaks, preserved FSDP traversal as a side effect).

Neither addresses the missing _load_from_state_dict.
Happy to file a PR with the fix sketch + tests if useful.