Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions modelopt/torch/quantization/utils/layerwise_calib.py
Copy link
Copy Markdown
Contributor

@realAsma realAsma May 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sugunav14 should we have just one global next_inputs.pt and keep updating it instead of saving next_inputs.pt to each layer_folder/ and deleting it? (So no per-layer next_inputs.pt)

Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,12 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:

print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")

def prune_old_next_inputs(self, layer_idx: int) -> None:
    """Delete the next_inputs.pt two layers back, keeping only the last two on disk."""
    # Activations older than layer_idx - 1 can never be a resume point again,
    # so the copy at layer_idx - 2 is safe to drop.
    stale_path = os.path.join(
        _layer_dir(self.checkpoint_dir, layer_idx - 2), "next_inputs.pt"
    )
    if not os.path.isfile(stale_path):
        return
    os.remove(stale_path)

def save(
self,
layer_idx: int,
Expand All @@ -650,6 +656,11 @@ def save(
) -> None:
"""Snapshot layer state and write checkpoint to disk in one step.

After the manifest commits, the now-redundant ``next_inputs.pt`` for
the second-most-recent layer is deleted, keeping at most two activation
copies on disk (the just-consumed one for ``layer_idx`` and the
just-written one for ``layer_idx + 1``).

Args:
layer_idx: Index of the layer just calibrated.
layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2).
Expand Down Expand Up @@ -680,5 +691,8 @@ def save(
_move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None,
self.num_layers,
)

self.prune_old_next_inputs(layer_idx)

suffix = " (final)" if next_layer_inputs is None else ""
print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}")
91 changes: 79 additions & 12 deletions tests/unit/torch/quantization/test_sequential_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from types import SimpleNamespace

import pytest
import torch
import torch.nn as nn

Expand Down Expand Up @@ -98,30 +99,96 @@ def test_full_run_creates_checkpoints(monkeypatch, tmp_path):
assert os.path.isfile(os.path.join(layer_dir, "weights.pt"))
assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt"))
assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt"))
# All layers except the last should have next_inputs
assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt"))
# Pruning trace with n_layers=3:
# after layer 0: cutoff = -2, no prune
# after layer 1: cutoff = -1, no prune
# after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt
# Final layer never has a next_inputs.pt of its own.
assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt"))
assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt"))
assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt"))


def test_resume_matches_full_run(monkeypatch, tmp_path):
"""Resume from a truncated checkpoint produces the same final weights as a full run."""
def test_only_last_two_next_inputs_kept(monkeypatch, tmp_path):
    """After a multi-layer run, only the most recent ``next_inputs.pt`` is
    retained; older layer dirs keep weights/qstate/meta but have their
    ``next_inputs.pt`` pruned."""
    _register_test_discoverer(monkeypatch)
    n_layers = 5
    ckpt_dir = str(tmp_path / "ckpt")
    model, forward_loop = _make_model_and_forward(n_layers=n_layers)
    layerwise_calibrate(model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir)

    def _exists(*parts):
        # True iff ckpt_dir/<parts...> is a regular file on disk.
        return os.path.isfile(os.path.join(ckpt_dir, *parts))

    # Pruning trace for n_layers=5:
    #   after layer 0: cutoff = -2, no prune
    #   after layer 1: cutoff = -1, no prune
    #   after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt
    #   after layer 3: cutoff = 1, delete layer_0001/next_inputs.pt
    #   after layer 4: cutoff = 2, delete layer_0002/next_inputs.pt
    # Layer 3 retains its next_inputs.pt (the resume point for layer 4);
    # layer 4 is the final layer and never has one.
    for i in range(3):
        assert not _exists(f"layer_{i:04d}", "next_inputs.pt"), (
            f"layer_{i:04d}/next_inputs.pt should have been pruned"
        )
    assert _exists("layer_0003", "next_inputs.pt")
    assert not _exists("layer_0004", "next_inputs.pt")

    # Static per-layer files survive pruning.
    for i in range(n_layers):
        for fname in ("weights.pt", "quantizer_state.pt", "output_meta.pt"):
            assert _exists(f"layer_{i:04d}", fname), (
                f"layer_{i:04d}/{fname} should be retained"
            )


def test_resume_matches_full_run(monkeypatch, tmp_path):
"""Resume after a simulated crash matches a full-run result.

The crash is injected by raising in the calibration function partway
through, leaving the checkpoint in the same state a real crash would:
the manifest points to the last successfully saved layer and that
layer's next_inputs.pt is still on disk (pruning hasn't caught up to it).
"""
_register_test_discoverer(monkeypatch)

# Full reference run
# Reference: a complete run in its own directory.
ref_dir = str(tmp_path / "ref")
ref_model, forward_loop = _make_model_and_forward(n_layers=3)
layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir)
layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ref_dir)
ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()}

# Simulate crash after layer 0: truncate manifest
manifest_path = os.path.join(ckpt_dir, "manifest.json")
with open(manifest_path, "w") as f:
json.dump({"last_completed_layer": 0, "num_layers": 3}, f)
# Crash run: in a fresh dir, raise during the second layer's calibration so
# only layer 0 has been saved when the calibration loop unwinds.
resume_dir = str(tmp_path / "resume")
call_count = {"n": 0}

def _crash_after_first_layer_calib(layer, forward_loop, **kwargs):
call_count["n"] += 1
if call_count["n"] > 1:
raise RuntimeError("simulated crash before layer 1 completes")
_dummy_calib_func(layer, forward_loop, **kwargs)

crash_model, crash_forward = _make_model_and_forward(n_layers=3)
with pytest.raises(RuntimeError, match="simulated crash"):
layerwise_calibrate(
crash_model,
crash_forward,
_crash_after_first_layer_calib,
checkpoint_dir=resume_dir,
)

# Mid-crash invariant: manifest at layer 0; its next_inputs.pt is intact
# (cutoff = -2, no prune has happened yet).
with open(os.path.join(resume_dir, "manifest.json")) as f:
manifest = json.load(f)
assert manifest["last_completed_layer"] == 0
assert manifest["num_layers"] == 3
assert os.path.isfile(os.path.join(resume_dir, "layer_0000", "next_inputs.pt"))

# Resume from a fresh model
# Resume on a fresh model in the same checkpoint dir.
resumed_model, forward_loop = _make_model_and_forward(n_layers=3)
layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir)
layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=resume_dir)

for name, ref_param in ref_weights.items():
resumed_param = dict(resumed_model.named_parameters())[name]
Expand Down
Loading