167 changes: 167 additions & 0 deletions docs/on_demand_checkpointing.md
@@ -0,0 +1,167 @@
# On-Demand Checkpointing

On-demand checkpointing enables graceful checkpoint-and-exit when termination
signals are received during training. It is designed for environments like
OpenShift AI and Kubeflow where training jobs can be preempted at any time.

## How It Works

When enabled, the system installs signal handlers in the parent (launcher)
process that catch termination signals before the hard SIGKILL. When a signal
arrives:

1. The parent writes a trigger file to `/dev/shm` (a fast, node-local tmpfs).
2. Worker processes check for the trigger file at multiple synchronization
points during each training step.
3. Workers coordinate via `all_reduce` so that if any rank on any node
detects the trigger, all ranks agree to save.
4. A full-state checkpoint (model + optimizer + LR scheduler) is saved.
5. Workers exit cleanly.

On resume, the training loop detects the mid-epoch checkpoint, restores the
full training state, and fast-forwards to the exact step where training was
interrupted.
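The worker side of steps 2–3 can be sketched as follows. This is a minimal illustration, not the library's actual code: `should_checkpoint` and the injected `allreduce_max` callable are hypothetical stand-ins (the real workers use `torch.distributed.all_reduce` with `ReduceOp.MAX`):

```python
import os

def should_checkpoint(trigger_path: str, allreduce_max) -> bool:
    """Steps 2-3 above: check the node-local trigger file, then combine
    the per-rank votes so one detection makes every rank save.

    allreduce_max stands in for torch.distributed.all_reduce with
    ReduceOp.MAX; it receives this rank's flag (0 or 1) and returns the
    maximum across all ranks.
    """
    local_flag = 1 if os.path.exists(trigger_path) else 0
    return allreduce_max(local_flag) == 1
```

Because the vote takes the maximum, a rank whose node never received the signal still agrees to save as soon as any other rank reports the trigger.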

## Signals Handled

The following signals are intercepted (SIGKILL cannot be caught):

| Signal | Source |
|--------|--------|
| SIGTERM | Kubernetes graceful shutdown (default) |
| SIGINT | Ctrl-C / some job controllers |
| SIGUSR1 | Custom preemption controllers |
| SIGUSR2 | Custom preemption controllers |
| SIGXCPU | CPU time limit exceeded (resource quotas) |
| SIGHUP | Terminal disconnect / some eviction paths |
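Installing one handler for every signal in the table amounts to a short loop. A sketch under illustrative names (`on_signal` is whatever callback writes the trigger file):

```python
import signal

# Catchable termination signals from the table above; SIGKILL is
# deliberately absent because it cannot be intercepted.
HANDLED_SIGNALS = (
    signal.SIGTERM,
    signal.SIGINT,
    signal.SIGUSR1,
    signal.SIGUSR2,
    signal.SIGXCPU,
    signal.SIGHUP,
)

def install_handlers(on_signal):
    """Route every catchable termination signal to one callback."""
    for sig in HANDLED_SIGNALS:
        signal.signal(sig, on_signal)
```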

## Usage

### Python API

```python
from instructlab.training.config import TorchrunArgs, TrainingArgs
from instructlab.training import run_training

torch_args = TorchrunArgs(
nproc_per_node=8,
nnodes=1,
node_rank=0,
rdzv_id=12345,
rdzv_endpoint="127.0.0.1:29500",
)

train_args = TrainingArgs(
model_path="Qwen/Qwen2-1.5B-Instruct",
data_path="./data.jsonl",
data_output_dir="./processed",
ckpt_output_dir="./checkpoints",
num_epochs=3,
on_demand_checkpointing=True, # Enable the feature
# ... other training args
)

run_training(torch_args, train_args)
```

### CLI

```bash
torchrun --nproc-per-node=8 \
instructlab/training/main_ds.py \
--model_name_or_path Qwen/Qwen2-1.5B-Instruct \
--data_path ./data.jsonl \
--output_dir ./checkpoints \
--on_demand_checkpointing \
...
```

## Resume Behavior

When a checkpoint saved by on-demand checkpointing is found in the output
directory, the training loop automatically:

1. Loads the full optimizer and LR scheduler state from the checkpoint.
2. Reads `global_step` from the checkpoint metadata to determine where
training was interrupted.
3. Resumes at the **same epoch** and fast-forwards to the exact step,
skipping already-completed batches.

This differs from epoch-boundary checkpoints, which resume at the start of
the next epoch.

### Checkpoint Metadata

On-demand checkpoints store additional metadata compared to epoch-boundary
checkpoints:

```json
{
"current_epoch": 0,
"samples_seen": 144,
"global_step": 19
}
```

The `global_step` field is what distinguishes an on-demand checkpoint from an
epoch-boundary one. When present, the resume logic keeps `current_epoch`
unchanged and sets `last_step = global_step` to enable fast-forwarding.
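The resume rule above can be expressed as a small function. `resume_position` is an illustrative name, not the library's API:

```python
def resume_position(metadata: dict) -> tuple[int, int]:
    """Map checkpoint metadata to (epoch_to_resume, last_step).

    An on-demand checkpoint carries global_step, so we stay in the same
    epoch and fast-forward; an epoch-boundary checkpoint has no
    global_step, so we start the next epoch from step 0.
    """
    if "global_step" in metadata:
        return metadata["current_epoch"], metadata["global_step"]
    return metadata["current_epoch"] + 1, 0
```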

## Multi-Node Training

The trigger file mechanism works correctly across multiple nodes:

- The trigger file lives on `/dev/shm`, which is node-local. Each node's
parent process writes its own trigger file when it receives a signal.
- Workers use `all_reduce(MAX)` to synchronize: if any rank on any node
detects a trigger, all ranks on all nodes agree to save.
- The checkpoint itself is saved to the shared filesystem (the configured
`ckpt_output_dir`), accessible by all nodes on resume.

## Stale Trigger Files

If a previous training run was killed before workers could clean up the
trigger file, the new run's `ParentSignalHandler` detects and removes it
during initialization. This prevents a new job from immediately
checkpointing and exiting due to a leftover trigger from a prior run.
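The cleanup amounts to an unconditional remove at startup. A minimal sketch (the function name is illustrative):

```python
import os

def clear_stale_trigger(trigger_path: str) -> bool:
    """Delete a trigger file left over from a killed run.

    Returns True if a stale file was actually removed. Called once
    during handler initialization, before any worker starts polling.
    """
    try:
        os.remove(trigger_path)
        return True
    except FileNotFoundError:
        return False
```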

## Manually Triggering a Checkpoint

You can trigger a checkpoint-and-exit without sending a signal by writing
the trigger file directly. This is useful for debugging, testing, or
integration with custom orchestration that doesn't use Unix signals.

The trigger file is always at a fixed path. To trigger a checkpoint
(e.g. via `kubectl exec` into the training pod):

```bash
touch /dev/shm/instructlab_checkpoint_requested
```

Workers check for the trigger file at each synchronization point in the
training loop (multiple times per step). Once any rank on any node detects
it, all ranks coordinate via `all_reduce` to save a checkpoint and exit.

You only need to write the file on **one node** — the `all_reduce` ensures
all nodes participate even if they don't see the file locally.

From Python:

```python
from instructlab.training.on_demand_checkpoint import write_trigger_file

write_trigger_file()
```

## Kubernetes / OpenShift Configuration

To give workers enough time to save a checkpoint before the hard SIGKILL,
increase `terminationGracePeriodSeconds` in your pod spec:

```yaml
spec:
terminationGracePeriodSeconds: 300 # 5 minutes
```

The default of 30 seconds may not be enough for large models. The checkpoint
save time depends on model size, number of GPUs, and filesystem speed.
51 changes: 43 additions & 8 deletions src/instructlab/training/batch_loss_manager.py
@@ -7,7 +7,8 @@
"""

# Standard
from dataclasses import dataclass
from collections.abc import Callable
from dataclasses import dataclass, field
import logging

# Third Party
@@ -33,6 +34,7 @@ class BatchMetrics:
accumulated_aux_loss: torch.Tensor | None
grad_accum_steps: int
num_minibatches: int
interrupted: bool = field(default=False)


class BatchLossManager:
@@ -62,12 +64,22 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
def process_batch(
self,
batch: list[CollatedItem],
interrupt_check: Callable[[], bool] | None = None,
) -> tuple[BatchMetrics, float]:
"""
Process a batch of minibatches, computing losses and accumulating gradients.

Args:
batch: List of minibatches to process
interrupt_check: Optional callback invoked at three points per
minibatch: before forward, before backward, and after
backward. If it returns ``True`` at any point, processing
stops early and ``BatchMetrics.interrupted`` is set. Used by
on-demand checkpointing to react as quickly as possible
instead of waiting for the full optimizer step.

Returns:
tuple: (BatchMetrics, average_loss_across_ranks)
@@ -82,9 +94,15 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_loss = 0.0
accumulated_aux_loss = 0.0
grad_accum_steps = 0
interrupted = False

# process each minibatch
for mb in batch:
# Check for on-demand checkpoint before forward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

# extract minibatch-specific info
micro_batch_size = mb["num_samples"]
total_length = mb["total_length"]
@@ -96,10 +114,16 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
# prepare model inputs
model_inputs = self._prepare_model_inputs(mb)

# compute loss and backward pass
# compute loss (forward pass)
scaled_loss, raw_losses = self.model.compute_loss(
model_inputs, self.world_size, batch_num_loss_counted_tokens
)

# Check for on-demand checkpoint before backward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

self.accelerator.backward(scaled_loss)

# accumulate losses
@@ -108,6 +132,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
if raw_losses.aux_loss is not None:
accumulated_aux_loss += raw_losses.aux_loss

# Check for on-demand checkpoint after backward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

# reduce metrics across ranks
batch_total_samples, batch_total_length = self._reduce_metrics(
batch_total_samples, batch_total_length
@@ -127,6 +156,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_aux_loss=accumulated_aux_loss,
grad_accum_steps=grad_accum_steps,
num_minibatches=num_minibatches,
interrupted=interrupted,
)

return metrics, avg_loss_across_ranks
@@ -165,8 +195,8 @@ def _reduce_metrics(

def _compute_average_loss(
self,
accumulated_loss: torch.Tensor,
accumulated_aux_loss: torch.Tensor | None,
accumulated_loss: torch.Tensor | float,
accumulated_aux_loss: torch.Tensor | float | None,
batch_num_loss_counted_tokens: int,
) -> float:
"""Compute average loss across all ranks for metrics logging."""
@@ -177,11 +207,16 @@ def _compute_average_loss(
if accumulated_aux_loss is not None:
total_batch_loss += accumulated_aux_loss

# Extract scalar value — accumulated_loss may be a plain float if the
# minibatch loop was interrupted before any forward pass completed.
if isinstance(total_batch_loss, torch.Tensor):
loss_value = total_batch_loss.detach().item()
else:
loss_value = float(total_batch_loss)

# reduce across ranks
avg_loss_across_ranks = self.accelerator.reduce(
torch.tensor(
total_batch_loss.detach().item(), device=self.accelerator.device
),
torch.tensor(loss_value, device=self.accelerator.device),
reduction="mean",
).item()

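The new `interrupt_check` hook and `BatchMetrics.interrupted` flag would be consumed by the training loop roughly like this. A hedged sketch of the calling pattern, not code from this PR: `train_step`, `trigger_file_exists`, and `save_and_exit` are hypothetical names.

```python
def train_step(manager, batch, trigger_file_exists, save_and_exit):
    """One training step that honors an on-demand checkpoint request.

    manager is a BatchLossManager; trigger_file_exists is the cheap
    per-minibatch check; save_and_exit performs the full-state save.
    """
    metrics, avg_loss = manager.process_batch(
        batch, interrupt_check=trigger_file_exists
    )
    if metrics.interrupted:
        # The partial batch's gradients die with the process; checkpoint
        # now rather than finishing the optimizer step.
        save_and_exit()
        return None
    return avg_loss
```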
13 changes: 13 additions & 0 deletions src/instructlab/training/config.py
@@ -351,6 +351,19 @@ class TrainingArgs(BaseModel):
description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
)

on_demand_checkpointing: bool = Field(
default=False,
description=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, the parent process intercepts termination signals "
"(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a "
"trigger file to /dev/shm. Worker processes check for this trigger "
"after each minibatch backward pass and collectively save a distributed "
"checkpoint before exiting gracefully. Designed for OpenShift AI / "
"KubeFlow training jobs where preemption signals must be handled."
),
)

@model_validator(mode="after")
def validate_validation_config(self):
if not 0.0 <= self.validation_split < 1.0: