Skip to content

Add on-demand full-state checkpointing for OpenShift AI / KubeFlow preemption#686

Open
RobotSail wants to merge 14 commits intomainfrom
claude/on-demand-checkpointing-dCsmo
Open

Add on-demand full-state checkpointing for OpenShift AI / KubeFlow preemption#686
RobotSail wants to merge 14 commits intomainfrom
claude/on-demand-checkpointing-dCsmo

Conversation

@RobotSail
Copy link
Copy Markdown
Member

@RobotSail RobotSail commented Feb 26, 2026

Implements signal-driven checkpoint-and-exit for distributed training jobs
running in OpenShift AI as KubeFlow training jobs or multi-node bare metal.

When on_demand_checkpointing=True is set in TrainingArgs:

  • Parent process (run_training) installs handlers for SIGTERM, SIGINT,
    SIGUSR1, SIGUSR2, SIGXCPU, and SIGHUP — covering all signals
    Kubernetes/OpenShift sends before the hard SIGKILL.
  • On signal receipt, a trigger file is atomically written to /dev/shm
    (tmpfs, shared within the pod, zero disk I/O).
  • Worker processes check for the trigger file after each optimizer step
    via an all_reduce(MAX) collective, ensuring global consensus across
    all ranks on all nodes.
  • When any rank detects the trigger, all ranks collectively save a
    full-state distributed checkpoint (model + optimizer + LR scheduler)
    then exit gracefully.
  • Parent waits up to 300s for workers to complete the checkpoint before
    proceeding with normal shutdown.

https://claude.ai/code/session_01HSxsk7SnMULJxy7uafe7t3

Summary by CodeRabbit

  • New Features
    • New training option and CLI flag to enable on-demand, signal-triggered full-state checkpointing for preemption.
    • Coordinated parent/worker checkpoint orchestration with signal handling, trigger-file coordination, graceful worker wait, and improved shutdown messaging.
    • Worker-side utilities to detect, save, and clear on-demand checkpoint requests.
    • Batch processing can be interrupted mid-minibatch and now reports an interrupted flag in batch metrics.
    • Mid-epoch checkpoints now record/resume using an optional global-step metadata field.

…eemption

Implements signal-driven checkpoint-and-exit for distributed training jobs
running in OpenShift AI as KubeFlow training jobs or multi-node bare metal.

When `on_demand_checkpointing=True` is set in TrainingArgs:

- Parent process (run_training) installs handlers for SIGTERM, SIGINT,
  SIGUSR1, SIGUSR2, SIGXCPU, and SIGHUP — covering all signals
  Kubernetes/OpenShift sends before the hard SIGKILL.
- On signal receipt, a trigger file is atomically written to /dev/shm
  (tmpfs, shared within the pod, zero disk I/O).
- Worker processes check for the trigger file after each optimizer step
  via an all_reduce(MAX) collective, ensuring global consensus across
  all ranks on all nodes.
- When any rank detects the trigger, all ranks collectively save a
  full-state distributed checkpoint (model + optimizer + LR scheduler)
  then exit gracefully.
- Parent waits up to 300s for workers to complete the checkpoint before
  proceeding with normal shutdown.

https://claude.ai/code/session_01HSxsk7SnMULJxy7uafe7t3
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Feb 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds an opt‑in on‑demand, signal‑triggered full‑state checkpointing mode: new TrainingArgs flag and CLI option, parent-side signal handler to create a shared trigger, worker-side consensus check and save-on-demand flow, and integration points in minibatch processing to interrupt and persist training state.

Changes

Cohort / File(s) Summary
Configuration
src/instructlab/training/config.py
Adds public boolean field on_demand_checkpointing: bool = False to TrainingArgs with description.
Training Orchestration / CLI
src/instructlab/training/main_ds.py
Adds --on_demand_checkpointing CLI flag; extends train(...) signature with on_demand_checkpointing; threads flag into subprocess launch (run_training) and appends CLI arg; parent installs/uninstalls ParentSignalHandler, waits for worker checkpoint saves, refines subprocess termination and error messages; integrates on‑demand checks into minibatch/optimizer flow and adjusts samples_seen increment timing.
On‑Demand Checkpointing Module
src/instructlab/training/on_demand_checkpoint.py
New module implementing /dev/shm trigger-file orchestration: write_trigger_file, trigger_file_exists, remove_trigger_file; ParentSignalHandler to atomically write trigger on signals; check_checkpoint_requested (local check + dist.all_reduce consensus); save_on_demand_checkpoint wrapper calling existing save with full_state=True/hf_format=True.
Batch Processing
src/instructlab/training/batch_loss_manager.py
Adds interrupted: bool to BatchMetrics; BatchLossManager.process_batch(...) gains optional interrupt_check callback and invokes it at three points per minibatch; returns BatchMetrics.interrupted; adjusts average-loss handling to accept float or Tensor from early interruptions.
Checkpoint Utilities
src/instructlab/training/utils.py
save_checkpoint(...) and save_full_state(...) accept optional global_step and include it in training_metadata.json when provided; load_latest_full_state(...) now resumes mid‑epoch when global_step is present by setting args.last_step and args.current_epoch accordingly; removed unconditional epoch increment prior to checkpoint branching.

Sequence Diagram

sequenceDiagram
    participant Parent as Parent Process
    participant Signal as ParentSignalHandler
    participant Worker as Worker Process(es)
    participant Trigger as Trigger File (/dev/shm)
    participant Dist as Distributed Backend
    participant Checkpoint as Checkpoint Storage

    Note over Parent,Worker: On‑demand checkpoint flow

    Parent->>Signal: install()
    Worker->>Worker: training loop -> process_batch(interrupt_check)

    Parent->>Parent: receives termination signal
    Parent->>Signal: handler invoked
    Signal->>Trigger: write_trigger_file(job_id)

    Worker->>Trigger: trigger_file_exists()
    Worker->>Dist: all_reduce(MAX, local_flag)
    Dist-->>Worker: consensus_flag

    alt consensus_flag == true
        Worker->>Checkpoint: save_on_demand_checkpoint(full_state=True)
        Checkpoint-->>Worker: saved
        Worker->>Trigger: remove_trigger_file()
        Worker->>Worker: exit early
    end

    Parent->>Parent: wait for workers (timeout)
    Parent->>Signal: uninstall()
    Parent->>Parent: exit
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A tiny tap in shared RAM's nest,
Ranks whisper "save" and do their best.
Parent rings the bell, the file appears,
Workers tuck state, then calm their gears.
Hop! A checkpoint safe — carrots and cheers.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 77.27% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The pull request title directly and accurately summarizes the main change: adding on-demand full-state checkpointing functionality for OpenShift AI/KubeFlow preemption scenarios.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch claude/on-demand-checkpointing-dCsmo

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@mergify mergify bot added the ci-failure label Feb 26, 2026
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
src/instructlab/training/on_demand_checkpoint.py (1)

225-229: Consider rank-gating the global-consensus log.

When a checkpoint is requested, every rank logs the same message. Logging only on rank 0 would reduce shutdown-time log bursts on large jobs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/instructlab/training/on_demand_checkpoint.py` around lines 225 - 229, The
log message is emitted by every rank when a checkpoint is requested; gate it to
only run on the main/rank-0 process to avoid log storms. Wrap the existing
logger.info block (the code that runs when requested is truthy) with a check for
the main process—e.g., if torch.distributed.is_initialized() and
torch.distributed.get_rank() == 0: or, if the project exposes a helper like
is_main_process(), use that—then call logger.info only inside that conditional
while leaving the checkpoint request flow unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 882-900: The code computes failure using process.poll() before
sending terminate()/kill(), so if the subprocess exits after forced shutdown the
failure status can be stale; update the logic inside the shutdown path in
main_ds.py to recompute process_code and failure after you perform
terminate()/kill() and any subsequent wait() calls (use process.wait with a
timeout then process.poll()), and then decide whether to log success or raise
based on the new failure value; apply the same fix for the second occurrence
referenced (the block around the later terminate/kill sequence) and reference
process.wait, process.poll, terminate(), kill(), and the logger.error messages
when updating the flow.
- Around line 821-833: The ParentSignalHandler is being instantiated without a
job identifier causing shared trigger files; update the instantiation to pass a
stable job id (e.g., use train_args.job_id or another unique training identifier
available in scope) so ParentSignalHandler(job_id=...) is used and the
handler.install() uses a namespaced trigger path; ensure the same job_id is
passed to any worker-side reader logic so trigger files live under a per-job
namespace instead of the global default.

---

Nitpick comments:
In `@src/instructlab/training/on_demand_checkpoint.py`:
- Around line 225-229: The log message is emitted by every rank when a
checkpoint is requested; gate it to only run on the main/rank-0 process to avoid
log storms. Wrap the existing logger.info block (the code that runs when
requested is truthy) with a check for the main process—e.g., if
torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: or, if
the project exposes a helper like is_main_process(), use that—then call
logger.info only inside that conditional while leaving the checkpoint request
flow unchanged.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f02ea6 and 4d82b3d.

📒 Files selected for processing (3)
  • src/instructlab/training/config.py
  • src/instructlab/training/main_ds.py
  • src/instructlab/training/on_demand_checkpoint.py

- Fix mypy error: properly type _original_handlers dict with _SignalHandler
  type alias instead of bare object
- Fix ruff/isort: remove duplicate comment, fix import ordering
- Namespace trigger file with rdzv_id as job_id so concurrent jobs sharing
  /dev/shm don't interfere with each other
- Recompute subprocess failure status after forced termination to avoid
  stale exit code
- Gate consensus log message to rank 0 to reduce log noise on large jobs
Move the checkpoint request check from after the full optimizer step to
after each minibatch's backward pass inside BatchLossManager.process_batch.
This ensures the system responds within one fwd+bwd cycle (~1-2s) even
when gradient accumulation spans many minibatches, giving more time to
save before Kubernetes sends SIGKILL after the grace period.

The check is passed as an optional interrupt_check callback to keep
checkpoint-specific logic out of BatchLossManager. When triggered, the
batch loop breaks early and the training loop saves the checkpoint
immediately, skipping the optimizer step to preserve the pre-step model
state for exact resumption.
When the training subprocess fails after an on-demand checkpoint signal
was received, the error message now includes guidance to increase
terminationGracePeriodSeconds or reduce fwd/bwd pass time so the
checkpoint check fires before SIGKILL arrives.
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 1160-1170: The help text for the "--on_demand_checkpointing"
argparse option is inaccurate: it says workers check "after each training step"
but the implementation triggers checks after each minibatch backward pass (see
BatchLossManager.process_batch). Update the parser.add_argument help string for
"--on_demand_checkpointing" to explicitly say the check happens after each
minibatch/backward pass (or "after each minibatch backward pass") and mention
that this is the granularity for checkpoint-trigger latency so the doc matches
the behavior in BatchLossManager.process_batch.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cffa29 and d089910.

📒 Files selected for processing (2)
  • src/instructlab/training/batch_loss_manager.py
  • src/instructlab/training/main_ds.py

Update --on_demand_checkpointing help text and TrainingArgs description
to accurately state that workers check for the trigger file after each
minibatch backward pass, not after each training step.
@mergify mergify bot added ci-failure and removed ci-failure labels Mar 2, 2026
Expand on-demand checkpointing to check for a trigger at five points:
1. Before each minibatch forward pass
2. Before each minibatch backward pass
3. After each minibatch backward pass (existing)
4. Before the optimizer step
5. After the optimizer step

This minimizes the latency between a termination signal arriving and
the checkpoint being saved, which is critical when the SIGKILL grace
period is short (e.g. 30s on OpenShift/Kubernetes).

Also cleans up the save-and-exit logic in train() by extracting a
_save_and_exit() helper to eliminate three nearly identical blocks,
and fixes _compute_average_loss to handle the case where the
minibatch loop is interrupted before any forward pass completes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mergify mergify bot removed the ci-failure label Mar 20, 2026
Two fixes to the on-demand checkpointing feature:

1. Stale trigger file cleanup: ParentSignalHandler.install() now checks
   for and removes any existing trigger file before installing signal
   handlers. If the file exists before handlers are installed, it's from
   a previous run that was killed before workers could clean it up.
   Prevents a new training job from immediately checkpointing and exiting.

2. Exact mid-epoch resume: save_on_demand_checkpoint() now persists
   global_step in the checkpoint metadata alongside current_epoch and
   samples_seen. On resume, load_latest_full_state() detects the
   global_step field and sets last_step accordingly, so the training
   loop fast-forwards to the exact step within the epoch. Without this,
   mid-epoch checkpoints would skip to the next epoch on resume, losing
   remaining steps.

Tested with Qwen2-1.5B-Instruct on 2 GPUs: interrupted at step 19/25,
checkpoint saved with global_step=19, resumed and completed steps 20-25.
@mergify mergify bot added the ci-failure label Mar 24, 2026
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/instructlab/training/utils.py (1)

897-897: ⚠️ Potential issue | 🔴 Critical

torch.load without weights_only=False will fail on PyTorch 2.6+.

PyTorch 2.6.0 (released January 29, 2025) changed the weights_only default from False to True. Since training_metadata.json contains a Python dictionary (not tensors or state_dicts), loading with the restricted unpickler may fail or encounter compatibility issues.

More importantly, using JSON for simple metadata is the correct approach: the data is not tensors, the filename extension .json should accurately reflect the format, and JSON provides human-readable serialization without pickle concerns.

🔧 Proposed fix using actual JSON format
+import json
+
 def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int, global_step: int | None = None):
     # ... existing code ...
     
     # save metadata file for current training status
     if accelerator.is_main_process:
         metadata = {"current_epoch": epoch, "samples_seen": samples_seen}
         # Save global_step when provided (on-demand mid-epoch checkpoints)
         # so that resume can fast-forward to the exact training step.
         if global_step is not None:
             metadata["global_step"] = global_step
-        torch.save(metadata, output_dir / "training_metadata.json")
+        with open(output_dir / "training_metadata.json", "w") as f:
+            json.dump(metadata, f)
         log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True)
 def load_latest_full_state(args, accelerator) -> None:
     # ... existing code ...
     
-    training_metadata = torch.load(latest / "training_metadata.json")
+    with open(latest / "training_metadata.json", "r") as f:
+        training_metadata = json.load(f)
     log_rank_0(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/instructlab/training/utils.py` at line 897, The code uses torch.load to
read "training_metadata.json" which is wrong for JSON and breaks on PyTorch
2.6+; replace the torch.load call that assigns training_metadata with a proper
JSON load (open the file at latest / "training_metadata.json" in text mode and
parse it with json.load or json.loads(latest.read_text())) so training_metadata
becomes the dict from the JSON file, and remove the torch.load usage; refer to
the training_metadata assignment and the latest Path variable when making the
change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 1178-1190: The help text for the --on_demand_checkpointing
argparse option (parser.add_argument in main_ds.py) incorrectly says "five
synchronization points per step"; update the help string to accurately describe
the implemented checks (3 checks per minibatch: before forward, before backward,
after backward, multiplied by N minibatches, plus 2 checks around the optimizer
step, i.e. 3*N + 2 per step) and rephrase it succinctly to reflect this variable
count and intent (on-demand full-state checkpointing via trigger file in
/dev/shm at those synchronization points).

---

Outside diff comments:
In `@src/instructlab/training/utils.py`:
- Line 897: The code uses torch.load to read "training_metadata.json" which is
wrong for JSON and breaks on PyTorch 2.6+; replace the torch.load call that
assigns training_metadata with a proper JSON load (open the file at latest /
"training_metadata.json" in text mode and parse it with json.load or
json.loads(latest.read_text())) so training_metadata becomes the dict from the
JSON file, and remove the torch.load usage; refer to the training_metadata
assignment and the latest Path variable when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9139806f-03c1-46c5-9474-721dbf73e883

📥 Commits

Reviewing files that changed from the base of the PR and between d21bf12 and 3794d5e.

📒 Files selected for processing (3)
  • src/instructlab/training/main_ds.py
  • src/instructlab/training/on_demand_checkpoint.py
  • src/instructlab/training/utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/instructlab/training/on_demand_checkpoint.py

@mergify mergify bot removed the ci-failure label Mar 24, 2026
@mergify mergify bot added the ci-failure label Mar 24, 2026
@mergify mergify bot removed the ci-failure label Mar 24, 2026
Drop the job_id suffix from the trigger file path. The file is now
always /dev/shm/instructlab_checkpoint_requested with no suffix.

The namespacing was defensive against concurrent jobs sharing /dev/shm,
but in practice Kubernetes pods each get their own /dev/shm. This makes
manual triggering trivial: touch /dev/shm/instructlab_checkpoint_requested
@mergify mergify bot added the ci-failure label Mar 24, 2026
@mergify mergify bot removed the ci-failure label Mar 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants