From 8c60a38e2c848fccce54d68b85cb983a6cd1ae8c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 11 May 2026 21:33:41 +0000 Subject: [PATCH 1/4] save only last 2 layers Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 4 +- .../quantization/utils/layerwise_calib.py | 31 +++++-- .../test_sequential_checkpoint.py | 91 ++++++++++++++++--- 3 files changed, 107 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fe4c3f77ce6..f8ed63a051a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1648,7 +1648,9 @@ def layerwise_calibrate( If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints are saved after each layer completes. On restart, calibration resumes from - the last completed layer. + the last completed layer. To bound disk usage, only the two most recent + layers' ``next_inputs.pt`` are retained — older copies are pruned after + each successful save. """ checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index aed403ad87b..bdd8bce681d 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -23,6 +23,7 @@ from __future__ import annotations +import contextlib import json import os import shutil @@ -474,7 +475,7 @@ def _read_manifest(checkpoint_dir: str) -> dict | None: def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: - """Atomically write manifest.json.""" + """Atomically write manifest.json — the single commit point per iteration.""" path = os.path.join(checkpoint_dir, "manifest.json") tmp = path + ".tmp" with open(tmp, "w") as f: @@ -575,10 +576,11 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint return cls(checkpoint_dir, num_layers, start_layer=start) def setup_resume(self, layers: nn.ModuleList) -> list | None: - """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. + """Load ``output_meta`` for skip layers 0..K-1, return ``next_inputs`` for layer K. Sets ``output_meta`` on each already-calibrated layer so that - skip mode can produce correctly shaped dummy outputs. + skip mode can produce correctly shaped dummy outputs. The activations + are read from ``layer_{K-1:04d}/next_inputs.pt``. """ if self.start_layer == 0: return None @@ -595,10 +597,13 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None: meta = _remap_output_metadata_device(meta, layer_device) layers[i]._layerwise_calib.output_meta = meta - d = _layer_dir(self.checkpoint_dir, last_ckpt) - next_inputs_path = os.path.join(d, "next_inputs.pt") + next_inputs_path = os.path.join( + _layer_dir(self.checkpoint_dir, last_ckpt), "next_inputs.pt" + ) if not os.path.isfile(next_inputs_path): - raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}") + raise FileNotFoundError( + f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}: {next_inputs_path}" + ) # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False) resume_device = get_module_device(layers[self.start_layer]) @@ -650,6 +655,11 @@ def save( ) -> None: """Snapshot layer state and write checkpoint to disk in one step. + After the manifest commits, the now-redundant ``next_inputs.pt`` for + the second-most-recent layer is deleted, keeping at most two activation + copies on disk (the just-consumed one for ``layer_idx`` and the + just-written one for ``layer_idx + 1``). + Args: layer_idx: Index of the layer just calibrated. layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). @@ -680,5 +690,14 @@ def save( _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, self.num_layers, ) + + # Prune the next_inputs.pt for the layer two back — we only need the + # last two to support resume from the most recent commit point. Done + # after the manifest write so a crash here is harmless. + stale = layer_idx - 2 + if stale >= 0: + with contextlib.suppress(FileNotFoundError): + os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")) + suffix = " (final)" if next_layer_inputs is None else "" print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") diff --git a/tests/unit/torch/quantization/test_sequential_checkpoint.py b/tests/unit/torch/quantization/test_sequential_checkpoint.py index 0e592a68c75..d494dd0a698 100644 --- a/tests/unit/torch/quantization/test_sequential_checkpoint.py +++ b/tests/unit/torch/quantization/test_sequential_checkpoint.py @@ -19,6 +19,7 @@ import os from types import SimpleNamespace +import pytest import torch import torch.nn as nn @@ -98,30 +99,96 @@ def test_full_run_creates_checkpoints(monkeypatch, tmp_path): assert os.path.isfile(os.path.join(layer_dir, "weights.pt")) assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt")) assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt")) - # All layers except the last should have next_inputs - assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) + # Pruning trace with n_layers=3: + # after layer 0: cutoff = -2, no prune + # after layer 1: cutoff = -1, no prune + # after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt + # Final layer never has a next_inputs.pt of its own. + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt")) assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt")) -def test_resume_matches_full_run(monkeypatch, tmp_path): - """Resume from a truncated checkpoint produces the same final weights as a full run.""" +def test_only_last_two_next_inputs_kept(monkeypatch, tmp_path): + """After a multi-layer run, only the most recent ``next_inputs.pt`` is + retained; older layer dirs keep weights/qstate/meta but have their + ``next_inputs.pt`` pruned.""" _register_test_discoverer(monkeypatch) + n_layers = 5 ckpt_dir = str(tmp_path / "ckpt") + model, forward_loop = _make_model_and_forward(n_layers=n_layers) + layerwise_calibrate(model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + # Pruning trace for n_layers=5: + # after layer 0: cutoff = -2, no prune + # after layer 1: cutoff = -1, no prune + # after layer 2: cutoff = 0, delete layer_0000/next_inputs.pt + # after layer 3: cutoff = 1, delete layer_0001/next_inputs.pt + # after layer 4: cutoff = 2, delete layer_0002/next_inputs.pt + # Layer 3 retains its next_inputs.pt (the resume point for layer 4); + # layer 4 is the final layer and never has one. + for i in range(3): + assert not os.path.isfile(os.path.join(ckpt_dir, f"layer_{i:04d}", "next_inputs.pt")), ( + f"layer_{i:04d}/next_inputs.pt should have been pruned" + ) + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0003", "next_inputs.pt")) + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0004", "next_inputs.pt")) + + # Static per-layer files survive pruning. + for i in range(n_layers): + for fname in ("weights.pt", "quantizer_state.pt", "output_meta.pt"): + assert os.path.isfile(os.path.join(ckpt_dir, f"layer_{i:04d}", fname)), ( + f"layer_{i:04d}/{fname} should be retained" + ) + + +def test_resume_matches_full_run(monkeypatch, tmp_path): + """Resume after a simulated crash matches a full-run result. + + The crash is injected by raising in the calibration function partway + through, leaving the checkpoint in the same state a real crash would: + the manifest points to the last successfully saved layer and that + layer's next_inputs.pt is still on disk (pruning hasn't caught up to it). + """ + _register_test_discoverer(monkeypatch) - # Full reference run + # Reference: a complete run in its own directory. + ref_dir = str(tmp_path / "ref") ref_model, forward_loop = _make_model_and_forward(n_layers=3) - layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ref_dir) ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()} - # Simulate crash after layer 0: truncate manifest - manifest_path = os.path.join(ckpt_dir, "manifest.json") - with open(manifest_path, "w") as f: - json.dump({"last_completed_layer": 0, "num_layers": 3}, f) + # Crash run: in a fresh dir, raise during the second layer's calibration so + # only layer 0 has been saved when the calibration loop unwinds. + resume_dir = str(tmp_path / "resume") + call_count = {"n": 0} + + def _crash_after_first_layer_calib(layer, forward_loop, **kwargs): + call_count["n"] += 1 + if call_count["n"] > 1: + raise RuntimeError("simulated crash before layer 1 completes") + _dummy_calib_func(layer, forward_loop, **kwargs) + + crash_model, crash_forward = _make_model_and_forward(n_layers=3) + with pytest.raises(RuntimeError, match="simulated crash"): + layerwise_calibrate( + crash_model, + crash_forward, + _crash_after_first_layer_calib, + checkpoint_dir=resume_dir, + ) + + # Mid-crash invariant: manifest at layer 0; its next_inputs.pt is intact + # (cutoff = -2, no prune has happened yet). + with open(os.path.join(resume_dir, "manifest.json")) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == 0 + assert manifest["num_layers"] == 3 + assert os.path.isfile(os.path.join(resume_dir, "layer_0000", "next_inputs.pt")) - # Resume from a fresh model + # Resume on a fresh model in the same checkpoint dir. resumed_model, forward_loop = _make_model_and_forward(n_layers=3) - layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=resume_dir) for name, ref_param in ref_weights.items(): resumed_param = dict(resumed_model.named_parameters())[name] From e84747aaea3d32a4f7bd139aee160a4e61df5cc6 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 11 May 2026 21:49:13 +0000 Subject: [PATCH 2/4] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .../torch/quantization/utils/layerwise_calib.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index bdd8bce681d..8b820bf9568 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -475,7 +475,7 @@ def _read_manifest(checkpoint_dir: str) -> dict | None: def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: - """Atomically write manifest.json — the single commit point per iteration.""" + """Atomically write manifest.json.""" path = os.path.join(checkpoint_dir, "manifest.json") tmp = path + ".tmp" with open(tmp, "w") as f: @@ -576,11 +576,10 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint return cls(checkpoint_dir, num_layers, start_layer=start) def setup_resume(self, layers: nn.ModuleList) -> list | None: - """Load ``output_meta`` for skip layers 0..K-1, return ``next_inputs`` for layer K. + """Load output_meta for skip layers 0..K-1, return next_inputs for layer K.. Sets ``output_meta`` on each already-calibrated layer so that - skip mode can produce correctly shaped dummy outputs. The activations - are read from ``layer_{K-1:04d}/next_inputs.pt``. + skip mode can produce correctly shaped dummy outputs. """ if self.start_layer == 0: return None @@ -597,13 +596,10 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None: meta = _remap_output_metadata_device(meta, layer_device) layers[i]._layerwise_calib.output_meta = meta - next_inputs_path = os.path.join( - _layer_dir(self.checkpoint_dir, last_ckpt), "next_inputs.pt" - ) + d = _layer_dir(self.checkpoint_dir, last_ckpt) + next_inputs_path = os.path.join(d, "next_inputs.pt") if not os.path.isfile(next_inputs_path): - raise FileNotFoundError( - f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}: {next_inputs_path}" - ) + raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}") # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False) resume_device = get_module_device(layers[self.start_layer]) From 256282c7e9e762edec1d929cb7a09e1d05c2067d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 11 May 2026 21:50:56 +0000 Subject: [PATCH 3/4] stray Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils/layerwise_calib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index 8b820bf9568..8038620ee62 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -576,7 +576,7 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint return cls(checkpoint_dir, num_layers, start_layer=start) def setup_resume(self, layers: nn.ModuleList) -> list | None: - """Load output_meta for skip layers 0..K-1, return next_inputs for layer K.. + """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. Sets ``output_meta`` on each already-calibrated layer so that skip mode can produce correctly shaped dummy outputs. From 92e144d92c168f88821d0d7fa02c266b027ccad0 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 11 May 2026 23:00:24 +0000 Subject: [PATCH 4/4] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 4 +--- .../torch/quantization/utils/layerwise_calib.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f8ed63a051a..fe4c3f77ce6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1648,9 +1648,7 @@ def layerwise_calibrate( If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints are saved after each layer completes. On restart, calibration resumes from - the last completed layer. To bound disk usage, only the two most recent - layers' ``next_inputs.pt`` are retained — older copies are pruned after - each successful save. + the last completed layer. """ checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index 8038620ee62..c495aafa53b 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -23,7 +23,6 @@ from __future__ import annotations -import contextlib import json import os import shutil @@ -641,6 +640,12 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") + def prune_old_next_inputs(self, layer_idx: int) -> None: + """Delete the next_inputs.pt two layers back, keeping only the last two on disk.""" + old = os.path.join(_layer_dir(self.checkpoint_dir, layer_idx - 2), "next_inputs.pt") + if os.path.isfile(old): + os.remove(old) + def save( self, layer_idx: int, @@ -687,13 +692,7 @@ def save( self.num_layers, ) - # Prune the next_inputs.pt for the layer two back — we only need the - # last two to support resume from the most recent commit point. Done - # after the manifest write so a crash here is harmless. - stale = layer_idx - 2 - if stale >= 0: - with contextlib.suppress(FileNotFoundError): - os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")) + self.prune_old_next_inputs(layer_idx) suffix = " (final)" if next_layer_inputs is None else "" print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}")