diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index aed403ad87b..c495aafa53b 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -640,6 +640,12 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") + def prune_old_next_inputs(self, layer_idx: int) -> None: + """Delete the next_inputs.pt two layers back, keeping only the last two on disk.""" + old = os.path.join(_layer_dir(self.checkpoint_dir, layer_idx - 2), "next_inputs.pt") + if os.path.isfile(old): + os.remove(old) + def save( self, layer_idx: int, @@ -650,6 +656,11 @@ def save( ) -> None: """Snapshot layer state and write checkpoint to disk in one step. + After the manifest commits, the now-redundant ``next_inputs.pt`` for + the second-most-recent layer is deleted, keeping at most two activation + copies on disk (the just-consumed one for ``layer_idx`` and the + just-written one for ``layer_idx + 1``). + Args: layer_idx: Index of the layer just calibrated. layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). @@ -680,5 +691,8 @@ def save( _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, self.num_layers, ) + + self.prune_old_next_inputs(layer_idx) + suffix = " (final)" if next_layer_inputs is None else "" print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") diff --git a/tests/unit/torch/quantization/test_sequential_checkpoint.py b/tests/unit/torch/quantization/test_sequential_checkpoint.py index 0e592a68c75..d494dd0a698 100644 --- a/tests/unit/torch/quantization/test_sequential_checkpoint.py +++ b/tests/unit/torch/quantization/test_sequential_checkpoint.py @@ -19,6 +19,7 @@ import os from types import SimpleNamespace +import pytest import torch import torch.nn as nn @@ -98,30 +99,96 @@ def test_full_run_creates_checkpoints(monkeypatch, tmp_path): assert os.path.isfile(os.path.join(layer_dir, "weights.pt")) assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt")) assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt")) - # All layers except the last should have next_inputs - assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) + # Pruning trace with n_layers=3: + # after layer 0: cutoff = -2, no prune + # after layer 1: cutoff = -1, no prune + # after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt + # Final layer never has a next_inputs.pt of its own. + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt")) assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt")) -def test_resume_matches_full_run(monkeypatch, tmp_path): - """Resume from a truncated checkpoint produces the same final weights as a full run.""" +def test_only_last_two_next_inputs_kept(monkeypatch, tmp_path): + """After a multi-layer run, only the most recent ``next_inputs.pt`` is + retained; older layer dirs keep weights/qstate/meta but have their + ``next_inputs.pt`` pruned.""" _register_test_discoverer(monkeypatch) + n_layers = 5 ckpt_dir = str(tmp_path / "ckpt") + model, forward_loop = _make_model_and_forward(n_layers=n_layers) + layerwise_calibrate(model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + # Pruning trace for n_layers=5: + # after layer 0: cutoff = -2, no prune + # after layer 1: cutoff = -1, no prune + # after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt + # after layer 3: cutoff = 1, delete layer_0001/next_inputs.pt + # after layer 4: cutoff = 2, delete layer_0002/next_inputs.pt + # Layer 3 retains its next_inputs.pt (the resume point for layer 4); + # layer 4 is the final layer and never has one. + for i in range(3): + assert not os.path.isfile(os.path.join(ckpt_dir, f"layer_{i:04d}", "next_inputs.pt")), ( + f"layer_{i:04d}/next_inputs.pt should have been pruned" + ) + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0003", "next_inputs.pt")) + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0004", "next_inputs.pt")) + + # Static per-layer files survive pruning. + for i in range(n_layers): + for fname in ("weights.pt", "quantizer_state.pt", "output_meta.pt"): + assert os.path.isfile(os.path.join(ckpt_dir, f"layer_{i:04d}", fname)), ( + f"layer_{i:04d}/{fname} should be retained" + ) + + +def test_resume_matches_full_run(monkeypatch, tmp_path): + """Resume after a simulated crash matches a full-run result. + + The crash is injected by raising in the calibration function partway + through, leaving the checkpoint in the same state a real crash would: + the manifest points to the last successfully saved layer and that + layer's next_inputs.pt is still on disk (pruning hasn't caught up to it). + """ + _register_test_discoverer(monkeypatch) - # Full reference run + # Reference: a complete run in its own directory. + ref_dir = str(tmp_path / "ref") ref_model, forward_loop = _make_model_and_forward(n_layers=3) - layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ref_dir) ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()} - # Simulate crash after layer 0: truncate manifest - manifest_path = os.path.join(ckpt_dir, "manifest.json") - with open(manifest_path, "w") as f: - json.dump({"last_completed_layer": 0, "num_layers": 3}, f) + # Crash run: in a fresh dir, raise during the second layer's calibration so + # only layer 0 has been saved when the calibration loop unwinds. + resume_dir = str(tmp_path / "resume") + call_count = {"n": 0} + + def _crash_after_first_layer_calib(layer, forward_loop, **kwargs): + call_count["n"] += 1 + if call_count["n"] > 1: + raise RuntimeError("simulated crash before layer 1 completes") + _dummy_calib_func(layer, forward_loop, **kwargs) + + crash_model, crash_forward = _make_model_and_forward(n_layers=3) + with pytest.raises(RuntimeError, match="simulated crash"): + layerwise_calibrate( + crash_model, + crash_forward, + _crash_after_first_layer_calib, + checkpoint_dir=resume_dir, + ) + + # Mid-crash invariant: manifest at layer 0; its next_inputs.pt is intact + # (cutoff = -2, no prune has happened yet). + with open(os.path.join(resume_dir, "manifest.json")) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == 0 + assert manifest["num_layers"] == 3 + assert os.path.isfile(os.path.join(resume_dir, "layer_0000", "next_inputs.pt")) - # Resume from a fresh model + # Resume on a fresh model in the same checkpoint dir. resumed_model, forward_loop = _make_model_and_forward(n_layers=3) - layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=resume_dir) for name, ref_param in ref_weights.items(): resumed_param = dict(resumed_model.named_parameters())[name]