
Linear4bit._save_to_state_dict writes QuantState keys but no _load_from_state_dict consumes them (asymmetric serialization) #1946

@neil-the-nowledgeable

Description


Summary

bitsandbytes.nn.Linear4bit overrides _save_to_state_dict (bitsandbytes/nn/modules.py:593) to write QuantState components alongside the packed weight:

def _save_to_state_dict(self, destination, prefix, keep_vars):
    super()._save_to_state_dict(destination, prefix, keep_vars)  # weight + bias
    if getattr(self.weight, "quant_state", None) is not None:
        for k, v in self.weight.quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + k] = v if keep_vars else v.detach()

Resulting state_dict keys for a Linear4bit (with compress_statistics=True):

  • weight (packed 4-bit tensor)
  • weight.absmax
  • weight.quant_map
  • weight.nested_absmax
  • weight.nested_quant_map
  • weight.quant_state.bitsandbytes__nf4 (or __fp4)

Linear4bit does not define a corresponding _load_from_state_dict. It inherits nn.Linear._load_from_state_dict, which only consumes weight and bias. The QuantState keys land in unexpected_keys during load. With strict=True (the model.load_state_dict() default) this raises RuntimeError.

This is asymmetric relative to Linear8bitLt, which defines both _save_to_state_dict (line 1095) and _load_from_state_dict (line 1119), with the latter explicitly walking unexpected_keys to consume the SCB tensor and remove it from the unexpected list.
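
For reference, the load side of that Linear8bitLt pattern looks roughly like the following (an abridged paraphrase for illustration, not the verbatim upstream code):

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # Let nn.Linear consume 'weight' and 'bias' first.
    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)
    # Then claim the extra statistics tensor back out of unexpected_keys.
    for key in list(unexpected_keys):
        if key[len(prefix):] == "SCB":
            self.weight.SCB.copy_(state_dict[key])
            unexpected_keys.remove(key)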

Reproducer

import torch
import bitsandbytes as bnb

src = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                       compute_dtype=torch.bfloat16,
                       compress_statistics=True)
src = src.to('cuda')  # triggers quantize, populates quant_state

sd = src.state_dict()
print("state_dict keys:", list(sd.keys()))
# ['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax',
#  'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__nf4']

dst = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                       compute_dtype=torch.bfloat16,
                       compress_statistics=True)
dst = dst.to('cuda')

# Fails with strict=True (the default):
dst.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for Linear4bit:
#     Unexpected key(s) in state_dict: "weight.absmax", "weight.quant_map",
#     "weight.nested_absmax", "weight.nested_quant_map",
#     "weight.quant_state.bitsandbytes__nf4"

With strict=False, the QuantState components are silently ignored: dst.weight.data becomes the packed bytes from src, but dst.weight.quant_state remains the one populated during dst.to('cuda'), not the one from sd. If src and dst were quantized from different weights (very common, e.g. loading a checkpoint into a freshly-initialized module), dst is left in a silently-corrupt state: packed bytes from one model, quantization statistics from another.
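
Continuing the reproducer above with strict=False makes the mismatch visible (same 64x64 shapes, so the serialized tensors are directly comparable):

# Loading with strict=False succeeds but mixes state from two quantizations.
dst.load_state_dict(sd, strict=False)

# The packed 4-bit bytes are now src's ...
print(torch.equal(dst.weight.data, sd['weight']))                        # True
# ... but the statistics still describe dst's own random init, not src's.
print(torch.equal(dst.weight.quant_state.absmax, sd['weight.absmax']))   # almost surely False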

Why this hasn't been reported as a user-facing bug

Standard workflows bypass model.load_state_dict() for bnb-quantized checkpoints:

  1. HF Transformers from_pretrained for pre-quantized bnb checkpoints: uses bnb.nn.Params4bit.from_prequantized(data, quantized_stats, device=...) directly via transformers/quantizers/quantizer_bnb_4bit.py::create_quantized_param. The quantized_stats dict is built by walking state_dict keys with prefix param_name + ".". This bypasses _load_from_state_dict entirely (schematic sketch after this list).
  2. PEFT save/load: get_peft_model_state_dict filters to LoRA-only keys ("lora_" in k); QuantState keys never enter the PEFT round-trip.
  3. Accelerate's fsdp2_load_full_state_dict: has its own broadcast + assign=True path; PR #3982 added key-based matching that filters non-Params keys.
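
A schematic paraphrase of that Transformers path (not the exact upstream code; load_prequantized_param, param_name, state_dict, and target_device are placeholder names):

import bitsandbytes as bnb

def load_prequantized_param(module, state_dict, param_name, target_device):
    # Collect every serialized QuantState component for this parameter.
    quantized_stats = {
        k: v for k, v in state_dict.items() if k.startswith(param_name + ".")
    }
    # Rebuild the parameter directly; _load_from_state_dict is never involved.
    module.weight = bnb.nn.Params4bit.from_prequantized(
        data=state_dict[param_name],
        quantized_stats=quantized_stats,
        device=target_device,
    )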

Paths that do hit the asymmetry:

  • torch.distributed.checkpoint (DCP) when the loaded state_dict contains QuantState keys
  • Custom training loops that save via model.state_dict() and resume via model.load_state_dict()
  • The docstring example at nn/modules.py:522 (quantized_model.load_state_dict(fp16_model.state_dict())) works only because the source is fp16 (no QuantState keys to be unexpected) — masking the bnb→bnb round-trip case

Discovery context

Surfaced while characterizing FSDP2 + Params4bit forward correctness on Jetson Orin Nano Super (sm_87) — the same investigation that led to #1945. The investigation needed to round-trip QuantState through state_dict() for a pre-shard-broadcast pattern; that worked on the save side and failed on the load side, leading to this finding. The two issues are unrelated mechanically — this one is a missing override, #1945 is a PyTorch FSDP2 NaN-canonicalization bug — but they share discovery provenance.

Suggested fix

Implement _load_from_state_dict paralleling Linear8bitLt's pattern. Sketch:

def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict,
    missing_keys, unexpected_keys, error_msgs,
):
    # Collect QuantState components in state_dict for this prefix
    qs_keys_to_consume = []
    quantized_stats = {}
    weight_dot_prefix = prefix + "weight."
    for k in list(state_dict.keys()):
        if k.startswith(weight_dot_prefix):
            qs_keys_to_consume.append(k)
            # store as `weight.<subkey>` for from_prequantized
            quantized_stats[k[len(prefix):]] = state_dict[k]

    # Standard nn.Linear path consumes 'weight' and 'bias'
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    )

    # If QuantState components are present, reconstruct via from_prequantized
    if quantized_stats:
        for k in qs_keys_to_consume:
            if k in unexpected_keys:
                unexpected_keys.remove(k)

        weight_data = self.weight.data  # already loaded by super()
        self.weight = Params4bit.from_prequantized(
            data=weight_data,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=weight_data.device,
            module=self,
        )

Edge cases:

  • If QuantState keys are absent (e.g., the existing fp16→bnb docstring example), quantized_stats is empty and the function falls through to standard nn.Linear behavior — no regression (see the sketch after this list).
  • The weight_dot_prefix match cleanly delimits per-Linear keys; the state_dict passed to _load_from_state_dict is already filtered to the current module's prefix by nn.Module.load_state_dict()'s recursion.
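
The first edge case mirrors the docstring pattern referenced above; a minimal sketch (sizes are illustrative):

import torch.nn as nn
import bitsandbytes as bnb

fp16_model = nn.Linear(64, 64, bias=False)
quantized = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4')
# Only 'weight' is present in the source state_dict, so quantized_stats stays
# empty and the standard nn.Linear load path runs unchanged.
quantized.load_state_dict(fp16_model.state_dict())
quantized = quantized.to('cuda')  # quantization still happens on the device move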

Test plan (for the fix)

  • Bnb→bnb round-trip: m1.state_dict() → m2.load_state_dict(strict=True) on CPU then CUDA Linear4bit; assert bnb.matmul_4bit produces identical output on synthetic input (minimal sketch below).
  • Existing fp16→bnb docstring example continues to pass (no regression).
  • Variant matrix: NF4 / FP4, compress_statistics on/off, quant_storage ∈ {uint8, bf16, fp16}, with/without bias.
  • strict=False behavior with partial QuantState keys (fall back to existing _quantize flow on .to(device) if applicable).
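
A minimal pytest-style sketch of the round-trip case, assuming the fix above is in place. It compares module outputs rather than calling bnb.matmul_4bit directly; make_layer is a local helper and sizes/dtypes are illustrative:

import torch
import bitsandbytes as bnb

def make_layer():
    return bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                             compute_dtype=torch.bfloat16,
                             compress_statistics=True).to('cuda')

def test_linear4bit_state_dict_roundtrip():
    torch.manual_seed(0)
    src, dst = make_layer(), make_layer()
    dst.load_state_dict(src.state_dict(), strict=True)  # must not raise with the fix

    x = torch.randn(4, 64, device='cuda', dtype=torch.bfloat16)
    # Identical packed weights + QuantState should give identical outputs.
    torch.testing.assert_close(dst(x), src(x))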

Severity

Latent — currently masked. Does not surface in standard HF + PEFT workflows because they bypass load_state_dict. But it's a contract violation between _save_to_state_dict and the standard nn.Module load mechanism, and surfaces for:

  • torch.distributed.checkpoint (DCP) usage
  • Custom training-loop checkpoint code
  • Any tooling that round-trips through model.state_dict() + model.load_state_dict()

Adjacent PRs (for context)

Neither addresses the missing _load_from_state_dict.

Happy to file a PR with the fix sketch + tests if useful.
