diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md new file mode 100644 index 00000000..b21d71d5 --- /dev/null +++ b/docs/on_demand_checkpointing.md @@ -0,0 +1,167 @@ +# On-Demand Checkpointing + +On-demand checkpointing enables graceful checkpoint-and-exit when termination +signals are received during training. It is designed for environments like +OpenShift AI and KubeFlow where training jobs can be preempted at any time. + +## How It Works + +When enabled, the system installs signal handlers in the parent (launcher) +process that catch termination signals before the hard SIGKILL. When a signal +arrives: + +1. The parent writes a trigger file to `/dev/shm` (a fast, node-local tmpfs). +2. Worker processes check for the trigger file at multiple synchronization + points during each training step. +3. Workers coordinate via `all_reduce` so that if any rank on any node + detects the trigger, all ranks agree to save. +4. A full-state checkpoint (model + optimizer + LR scheduler) is saved. +5. Workers exit cleanly. + +On resume, the training loop detects the mid-epoch checkpoint, restores the +full training state, and fast-forwards to the exact step where training was +interrupted. 
+ +## Signals Handled + +The following signals are intercepted (SIGKILL cannot be caught): + +| Signal | Source | +|--------|--------| +| SIGTERM | Kubernetes graceful shutdown (default) | +| SIGINT | Ctrl-C / some job controllers | +| SIGUSR1 | Custom preemption controllers | +| SIGUSR2 | Custom preemption controllers | +| SIGXCPU | CPU time limit exceeded (resource quotas) | +| SIGHUP | Terminal disconnect / some eviction paths | + +## Usage + +### Python API + +```python +from instructlab.training.config import TorchrunArgs, TrainingArgs +from instructlab.training import run_training + +torch_args = TorchrunArgs( + nproc_per_node=8, + nnodes=1, + node_rank=0, + rdzv_id=12345, + rdzv_endpoint="127.0.0.1:29500", +) + +train_args = TrainingArgs( + model_path="Qwen/Qwen2-1.5B-Instruct", + data_path="./data.jsonl", + data_output_dir="./processed", + ckpt_output_dir="./checkpoints", + num_epochs=3, + on_demand_checkpointing=True, # Enable the feature + # ... other training args +) + +run_training(torch_args, train_args) +``` + +### CLI + +```bash +torchrun --nproc-per-node=8 \ + instructlab/training/main_ds.py \ + --model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --data_path ./data.jsonl \ + --output_dir ./checkpoints \ + --on_demand_checkpointing \ + ... +``` + +## Resume Behavior + +When a checkpoint saved by on-demand checkpointing is found in the output +directory, the training loop automatically: + +1. Loads the full optimizer and LR scheduler state from the checkpoint. +2. Reads `global_step` from the checkpoint metadata to determine where + training was interrupted. +3. Resumes at the **same epoch** and fast-forwards to the exact step, + skipping already-completed batches. + +This differs from epoch-boundary checkpoints, which resume at the start of +the next epoch. 
+ +### Checkpoint Metadata + +On-demand checkpoints store additional metadata compared to epoch-boundary +checkpoints: + +```json +{ + "current_epoch": 0, + "samples_seen": 144, + "global_step": 19 +} +``` + +The `global_step` field is what distinguishes an on-demand checkpoint from an +epoch-boundary one. When present, the resume logic keeps `current_epoch` +unchanged and sets `last_step = global_step` to enable fast-forwarding. + +## Multi-Node Training + +The trigger file mechanism works correctly across multiple nodes: + +- The trigger file lives on `/dev/shm`, which is node-local. Each node's + parent process writes its own trigger file when it receives a signal. +- Workers use `all_reduce(MAX)` to synchronize: if any rank on any node + detects a trigger, all ranks on all nodes agree to save. +- The checkpoint itself is saved to the shared filesystem (the configured + `ckpt_output_dir`), accessible by all nodes on resume. + +## Stale Trigger Files + +If a previous training run was killed before workers could clean up the +trigger file, the new run's `ParentSignalHandler` detects and removes it +during initialization. This prevents a new job from immediately +checkpointing and exiting due to a leftover trigger from a prior run. + +## Manually Triggering a Checkpoint + +You can trigger a checkpoint-and-exit without sending a signal by writing +the trigger file directly. This is useful for debugging, testing, or +integration with custom orchestration that doesn't use Unix signals. + +The trigger file is always at a fixed path. To trigger a checkpoint +(e.g. via `kubectl exec` into the training pod): + +```bash +touch /dev/shm/instructlab_checkpoint_requested +``` + +Workers check for the trigger file at each synchronization point in the +training loop (multiple times per step). Once any rank on any node detects +it, all ranks coordinate via `all_reduce` to save a checkpoint and exit. 
+ +You only need to write the file on **one node** — the `all_reduce` ensures +all nodes participate even if they don't see the file locally. + +From Python: + +```python +from instructlab.training.on_demand_checkpoint import write_trigger_file + +write_trigger_file() +``` + +## Kubernetes / OpenShift Configuration + +To give workers enough time to save a checkpoint before the hard SIGKILL, +increase `terminationGracePeriodSeconds` in your pod spec: + +```yaml +spec: + terminationGracePeriodSeconds: 300 # 5 minutes +``` + +The default of 30 seconds may not be enough for large models. The checkpoint +save time depends on model size, number of GPUs, and filesystem speed. diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index cc6da021..46e2af30 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -7,7 +7,8 @@ """ # Standard -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field import logging # Third Party @@ -33,6 +34,7 @@ class BatchMetrics: accumulated_aux_loss: torch.Tensor | None grad_accum_steps: int num_minibatches: int + interrupted: bool = field(default=False) class BatchLossManager: @@ -62,12 +64,22 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) - def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + def process_batch( + self, + batch: list[CollatedItem], + interrupt_check: Callable[[], bool] | None = None, + ) -> tuple[BatchMetrics, float]: """ Process a batch of minibatches, computing losses and accumulating gradients. Args: batch: List of minibatches to process + interrupt_check: Optional callback invoked at three points per + minibatch: before forward, before backward, and after + backward. 
If it returns ``True`` at any point, processing + stops early and ``BatchMetrics.interrupted`` is set. Used by + on-demand checkpointing to react as quickly as possible + instead of waiting for the full optimizer step. Returns: tuple: (BatchMetrics, average_loss_across_ranks) @@ -82,9 +94,15 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_loss = 0.0 accumulated_aux_loss = 0.0 grad_accum_steps = 0 + interrupted = False # process each minibatch for mb in batch: + # Check for on-demand checkpoint before forward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # extract minibatch-specific info micro_batch_size = mb["num_samples"] total_length = mb["total_length"] @@ -96,10 +114,16 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] # prepare model inputs model_inputs = self._prepare_model_inputs(mb) - # compute loss and backward pass + # compute loss (forward pass) scaled_loss, raw_losses = self.model.compute_loss( model_inputs, self.world_size, batch_num_loss_counted_tokens ) + + # Check for on-demand checkpoint before backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + self.accelerator.backward(scaled_loss) # accumulate losses @@ -108,6 +132,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # Check for on-demand checkpoint after backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -127,6 +156,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_aux_loss=accumulated_aux_loss, grad_accum_steps=grad_accum_steps, num_minibatches=num_minibatches, + interrupted=interrupted, ) 
return metrics, avg_loss_across_ranks @@ -165,8 +195,8 @@ def _reduce_metrics( def _compute_average_loss( self, - accumulated_loss: torch.Tensor, - accumulated_aux_loss: torch.Tensor | None, + accumulated_loss: torch.Tensor | float, + accumulated_aux_loss: torch.Tensor | float | None, batch_num_loss_counted_tokens: int, ) -> float: """Compute average loss across all ranks for metrics logging.""" @@ -177,11 +207,16 @@ def _compute_average_loss( if accumulated_aux_loss is not None: total_batch_loss += accumulated_aux_loss + # Extract scalar value — accumulated_loss may be a plain float if the + # minibatch loop was interrupted before any forward pass completed. + if isinstance(total_batch_loss, torch.Tensor): + loss_value = total_batch_loss.detach().item() + else: + loss_value = float(total_batch_loss) + # reduce across ranks avg_loss_across_ranks = self.accelerator.reduce( - torch.tensor( - total_batch_loss.detach().item(), device=self.accelerator.device - ), + torch.tensor(loss_value, device=self.accelerator.device), reduction="mean", ).item() diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 911c3898..11182a0e 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel): description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.", ) + on_demand_checkpointing: bool = Field( + default=False, + description=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, the parent process intercepts termination signals " + "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " + "trigger file to /dev/shm. Worker processes check for this trigger " + "after each minibatch backward pass and collectively save a distributed " + "checkpoint before exiting gracefully. Designed for OpenShift AI / " + "KubeFlow training jobs where preemption signals must be handled." 
+ ), + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 072b27c6..6a4813d1 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -173,6 +173,7 @@ def train( accelerator: Accelerator, val_data_loader=None, validation_frequency=None, + on_demand_checkpointing: bool = False, ): model.train() @@ -183,6 +184,33 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") + # Import on-demand checkpointing utilities once if the feature is enabled + if on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ( + check_checkpoint_requested, + save_on_demand_checkpoint, + ) + + base_logger.info("On-demand checkpointing is enabled in worker process.") + + def _save_and_exit(checkpoint_location: str) -> None: + """Save an on-demand checkpoint and exit the training loop.""" + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + global_step=global_step, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved (%s). Exiting training.", + checkpoint_location, + ) + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -220,13 +248,28 @@ def train( continue start = time.time() - # Process the batch using the BatchLossManager + # Process the batch using the BatchLossManager. + # When on-demand checkpointing is enabled, pass a callback so + # the check runs after every minibatch backward rather than + # waiting for the full optimizer step. 
+ _interrupt_check = ( + check_checkpoint_requested if on_demand_checkpointing else None + ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( - batch + batch, interrupt_check=_interrupt_check ) - # Update samples seen - samples_seen += batch_metrics.total_samples + # If the batch was interrupted by an on-demand checkpoint + # request, save immediately and exit — skip the optimizer step + # since we want to preserve the pre-step model state for + # exact resumption. + if batch_metrics.interrupted: + _save_and_exit("during minibatch processing") + return + + if on_demand_checkpointing and check_checkpoint_requested(): + _save_and_exit("before optimizer step") + return base_logger.info( f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" @@ -235,6 +278,13 @@ def train( # Take optimizer step after all minibatches accelerator.take_optimizer_step() + # Update samples seen after the optimizer step has been applied + samples_seen += batch_metrics.total_samples + + if on_demand_checkpointing and check_checkpoint_requested(): + _save_and_exit("after optimizer step") + return + if local_rank == 0: elapsed_time = time.time() - start overall_throughput = batch_metrics.total_samples / elapsed_time @@ -561,6 +611,7 @@ def main(args): accelerator=accelerator, val_data_loader=val_loader, validation_frequency=validation_frequency, + on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), ) dist.barrier() @@ -791,7 +842,26 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + if train_args.on_demand_checkpointing: + command.append("--on_demand_checkpointing") + logger.info("Running training command as subprocess: %s", " ".join(command)) + + # --- On-demand checkpointing: install signal handlers in the parent --- + signal_handler = 
None + if train_args.on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ParentSignalHandler + + # NOTE(review): the trigger file path is fixed (not namespaced), so concurrent + # jobs sharing /dev/shm on one node could interfere — TODO: namespace it (e.g. by rdzv_id). + signal_handler = ParentSignalHandler() + signal_handler.install() + logger.info( + "On-demand checkpointing is ENABLED. " + "Termination signals will trigger a full-state checkpoint before exit.", + ) + process = None interrupt: KeyboardInterrupt | Exception | None = None failure = False @@ -811,36 +881,85 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: interrupt = e finally: if "process" not in locals() or process is None: + if signal_handler is not None: + signal_handler.uninstall() return - # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) - process_code = process.poll() - failure = process_code != 0 - - if not failure: - logger.info("Operation completed successfully! 🎉") - else: - logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" + # If a signal was caught by the on-demand checkpoint handler, give + # the workers time to detect the trigger file and save a checkpoint + # before we start sending our own signals to the subprocess. + if signal_handler is not None and signal_handler.signal_received is not None: + logger.info( + "On-demand checkpoint: signal %s received. Waiting for workers to " + "save checkpoint before proceeding with shutdown...", + signal_handler.signal_received.name, ) + # Give workers generous time to complete the checkpoint save. + # The workers will exit on their own after saving. + try: + process.wait(timeout=300) + except subprocess.TimeoutExpired: + logger.warning( + "On-demand checkpoint: workers did not finish within 300s. " + "Proceeding with shutdown." 
+ ) - process.terminate() + # wait for the process to exit so we can properly read the exit code try: - logger.info("Waiting for process to exit, 60s...") process.wait(timeout=60) except subprocess.TimeoutExpired: + pass + process_code = process.poll() + + if process_code is not None and process_code == 0: + logger.info("Operation completed successfully!") + elif process_code is None: + logger.error("Training subprocess has not exited yet. Sending SIGTERM.") + process.terminate() + try: + logger.info("Waiting for process to exit, 60s...") + process.wait(timeout=60) + except subprocess.TimeoutExpired: + logger.error( + "Training subprocess did not terminate before timeout, sending SIGKILL." + ) + process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + else: logger.error( - "Training subprocess did not terminate before timeout, sending SIGKILL." + "Training subprocess exited with code %d.", + process_code, ) - process.kill() + + # Recompute final exit status after any forced shutdown + process_code = process.poll() + failure = process_code is None or process_code != 0 + + if signal_handler is not None: + signal_handler.uninstall() if interrupt: raise interrupt if failure: - raise RuntimeError( - "Suffered a failure during distributed training. Please see the training logs for more context." - ) + msg = "Suffered a failure during distributed training. Please see the training logs for more context." + if ( + signal_handler is not None + and signal_handler.signal_received is not None + ): + msg += ( + f"\n\nNote: signal {signal_handler.signal_received.name} was" + " received and on-demand checkpointing was enabled, but the" + " training subprocess did not exit cleanly. This usually" + " means the process was killed (SIGKILL) before the" + " checkpoint could be saved. 
To fix this, increase" + " terminationGracePeriodSeconds in your pod spec to give" + " workers more time, or reduce the model's forward/backward" + " pass time so the checkpoint check fires sooner." + ) + raise RuntimeError(msg) if __name__ == "__main__": @@ -1045,6 +1164,18 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) + parser.add_argument( + "--on_demand_checkpointing", + action="store_true", + default=False, + help=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, workers check for a trigger file in /dev/shm at multiple " + "synchronization points (three times per minibatch and twice around the " + "optimizer step) and collectively save a distributed checkpoint before " + "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + ), + ) parser.add_argument( "--use_liger", action="store_true", diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py new file mode 100644 index 00000000..be1771b5 --- /dev/null +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +On-demand checkpointing for distributed training. + +This module enables graceful checkpoint-and-exit when termination signals are +received. It is designed for environments like OpenShift AI / KubeFlow where +training jobs can be preempted at any time and the platform sends Unix signals +before killing the pod. + +Architecture +------------ +There are two sides to this feature: + +**Parent process** (``run_training`` in ``main_ds.py``): + Installs signal handlers that catch every signal OpenShift / Kubernetes can + send before a SIGKILL. When a signal arrives the handler writes a small + *trigger file* to ``/dev/shm`` (a tmpfs shared between containers in the + same pod). Because ``/dev/shm`` is node-local, every worker on the **same + node** can see the file instantly with zero network I/O. 
+ +**Worker processes** (torchrun children): + The training loop calls ``check_checkpoint_requested()`` at multiple + synchronization points per training step (three per minibatch plus + two around the optimizer step), allowing the system to react as + quickly as possible to termination signals: + + 1. **Before each minibatch forward pass** — no partial computation; + the current state is saved as-is. + 2. **Before each minibatch backward pass** — the forward result is + discarded; the pre-step state is saved. + 3. **After each minibatch backward pass** — gradients are computed but + not yet applied; the pre-step state is saved (gradients will be + recomputed on resume). + 4. **Before the optimizer step** — all minibatches are done and + gradients are ready, but the step is skipped; the pre-step state + is saved. + 5. **After the optimizer step** — the step has been applied; + ``samples_seen`` is updated and the post-step state is saved. + + Each rank checks its local ``/dev/shm`` for the trigger file, converts + the boolean to a tensor, and does an ``all_reduce(MAX)`` so that if + *any* rank on *any* node detected the trigger, *every* rank agrees to + save a checkpoint. This works correctly in multi-node training because + all_reduce is a global collective. + +Signals handled +--------------- +We intercept every signal that Kubernetes / OpenShift can deliver before the +hard SIGKILL (which cannot be caught): + +* **SIGTERM** – the standard graceful-shutdown signal. Kubernetes sends this + first (configurable via ``terminationGracePeriodSeconds``). +* **SIGINT** – sent on Ctrl-C or by some job controllers. +* **SIGUSR1 / SIGUSR2** – commonly used by batch schedulers and custom + preemption controllers to signal upcoming eviction. +* **SIGXCPU** – sent when CPU time limits are exceeded (relevant for jobs + with resource quotas). +* **SIGHUP** – sent when the controlling terminal disconnects; some + container runtimes forward this on pod eviction. 
+""" + +# Standard +from pathlib import Path +from typing import Callable, Optional, Union +import logging +import os +import signal +import tempfile +import types + +# Third Party +import torch +import torch.distributed as dist + +# Type alias matching the return type of signal.getsignal(). +_SignalHandler = Union[ + Callable[[int, Optional[types.FrameType]], None], int, signal.Handlers, None +] + +logger = logging.getLogger("instructlab.training") + +# --------------------------------------------------------------------------- +# Trigger file helpers +# --------------------------------------------------------------------------- + +# The trigger file lives in /dev/shm which is a tmpfs (RAM-backed filesystem). +# It is: +# 1. Extremely fast (no disk I/O). +# 2. Shared between all containers in the same Kubernetes pod. +# 3. Automatically cleaned up when the pod is destroyed. +_TRIGGER_DIR = Path("/dev/shm") +_TRIGGER_FILENAME = "instructlab_checkpoint_requested" + + +def _get_trigger_path() -> Path: + """Return the path to the checkpoint trigger file.""" + return _TRIGGER_DIR / _TRIGGER_FILENAME + + +def write_trigger_file() -> Path: + """Create the trigger file that tells workers to checkpoint. + + This is called from the *parent* process signal handler. + Returns the path that was written. + """ + path = _get_trigger_path() + # Use a atomic write via tempfile + rename to avoid partial reads. 
+ fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_") + try: + os.write(fd, b"1") + finally: + os.close(fd) + os.rename(tmp, path) + logger.info( + "On-demand checkpoint trigger file written: %s", + path, + ) + return path + + +def trigger_file_exists() -> bool: + """Check whether the trigger file exists (worker-side).""" + return _get_trigger_path().exists() + + +def remove_trigger_file() -> None: + """Remove the trigger file after the checkpoint has been saved.""" + path = _get_trigger_path() + try: + path.unlink(missing_ok=True) + except OSError: + pass + + +# --------------------------------------------------------------------------- +# Parent-side signal handling +# --------------------------------------------------------------------------- + +# Signals that OpenShift / Kubernetes / batch schedulers may send before +# the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. +_CATCHABLE_SIGNALS = ( + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths +) + + +class ParentSignalHandler: + """Installs signal handlers in the parent (launcher) process. + + When any of the catchable signals fire, the handler: + 1. Writes the trigger file to ``/dev/shm``. + 2. Records that a signal was received (so the caller can decide to + wait for the child process to finish checkpointing). + + The handler is idempotent – multiple signals will not create multiple + trigger files. 
+ + """ + + def __init__(self): + self.signal_received: Optional[signal.Signals] = None + self._original_handlers: dict[signal.Signals, _SignalHandler] = {} + self._trigger_written = False + + def install(self) -> None: + """Register signal handlers for all catchable signals.""" + # Clear any stale trigger file from a previous run. If the file + # exists before we've even installed signal handlers, it cannot + # be from this job — it's left over from a prior run that was + # killed before the workers could clean it up. + if trigger_file_exists(): + logger.info( + "On-demand checkpoint: clearing stale trigger file from " + "a previous run.", + ) + try: + remove_trigger_file() + except Exception: + logger.warning( + "On-demand checkpoint: failed to remove stale trigger file, " + "but continuing anyway.", + exc_info=True, + ) + + for sig in _CATCHABLE_SIGNALS: + try: + self._original_handlers[sig] = signal.getsignal(sig) + signal.signal(sig, self._handle) + except (OSError, ValueError): + # Some signals may not be available on all platforms + logger.debug("Could not install handler for %s", sig.name) + + logger.info( + "On-demand checkpoint signal handlers installed for: %s", + ", ".join(s.name for s in self._original_handlers), + ) + + def uninstall(self) -> None: + """Restore original signal handlers.""" + for sig, handler in self._original_handlers.items(): + try: + signal.signal(sig, handler) # type: ignore[arg-type] + except (OSError, ValueError): + pass + self._original_handlers.clear() + + def _handle(self, signum: int, _frame) -> None: + """Signal handler callback.""" + sig = signal.Signals(signum) + logger.info( + "On-demand checkpoint: received signal %s (%d). 
" + "Writing trigger file for workers to checkpoint before exit.", + sig.name, + signum, + ) + self.signal_received = sig + + if not self._trigger_written: + write_trigger_file() + self._trigger_written = True + + +# --------------------------------------------------------------------------- +# Worker-side synchronization +# --------------------------------------------------------------------------- + + +def check_checkpoint_requested() -> bool: + """Check across all ranks whether an on-demand checkpoint was requested. + + This function must be called by **all ranks** at the same point in the + training loop (it contains a collective all_reduce). + + Returns ``True`` if any rank detected the trigger file, meaning all + ranks should save a checkpoint. + """ + local_trigger = trigger_file_exists() + + # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY + # node saw the trigger, every rank gets True. + trigger_tensor = torch.tensor( + [1 if local_trigger else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + dist.all_reduce(trigger_tensor, op=dist.ReduceOp.MAX) + + requested = trigger_tensor.item() > 0 + + if requested: + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) + # Clean up the trigger file so that if the process somehow + # continues, we don't save again immediately. + remove_trigger_file() + + return requested + + +def save_on_demand_checkpoint( + args, + accelerator, + model, + tokenizer, + samples_seen: int, + epoch: int, + global_step: int, + is_lora: bool, +) -> None: + """Save a full-state distributed checkpoint for on-demand resume. + + This is a thin wrapper that calls the existing ``save_checkpoint`` + utility with ``full_state=True`` so that optimizer + LR scheduler + state are also persisted, enabling exact training resumption. 
+ + The ``global_step`` is saved to the checkpoint metadata so that + on resume the training loop can fast-forward to the exact step + within the epoch where training was interrupted. + """ + # First Party + from instructlab.training.utils import save_checkpoint + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: + logger.info( + "On-demand checkpoint: saving full-state checkpoint at " + "epoch=%d, global_step=%d, samples_seen=%d", + epoch, + global_step, + samples_seen, + ) + + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + full_state=True, + hf_format=True, + epoch=epoch, + global_step=global_step, + ) + + if local_rank == 0: + logger.info("On-demand checkpoint: checkpoint saved successfully.") diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index fc31858e..5a4bd8e0 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -794,6 +794,7 @@ def save_checkpoint( epoch: int = None, hf_format: bool = True, full_state: bool = False, + global_step: int | None = None, ) -> None: if hf_format: save_hf_format_accelerate( @@ -812,10 +813,18 @@ def save_checkpoint( is_lora=is_lora, epoch=epoch, samples_seen=samples_seen, + global_step=global_step, ) -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int): +def save_full_state( + args, + accelerator, + is_lora: bool, + epoch: int, + samples_seen: int, + global_step: int | None = None, +): """ Saves model, optimizer, and lr_scheduler state. TODO: save model config - decided not to do this. @@ -848,9 +857,11 @@ def _get_state_dict_patched(model, unwrap=False): # save metadata file for current training status if accelerator.is_main_process: - # TODO: should we set the global_step here rather than calculating global_step - # based on samples_seen? 
metadata = {"current_epoch": epoch, "samples_seen": samples_seen} + # Save global_step when provided (on-demand mid-epoch checkpoints) + # so that resume can fast-forward to the exact training step. + if global_step is not None: + metadata["global_step"] = global_step torch.save(metadata, output_dir / "training_metadata.json") log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True) @@ -895,10 +906,22 @@ def load_latest_full_state(args, accelerator) -> None: f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True ) - # previous epoch is basis for current epoch. - args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 args.__dict__["samples_seen"] = training_metadata["samples_seen"] + if "global_step" in training_metadata: + # On-demand mid-epoch checkpoint: resume at the same epoch and + # fast-forward to the exact step via last_step. + args.__dict__["current_epoch"] = training_metadata["current_epoch"] + args.__dict__["last_step"] = training_metadata["global_step"] + log_rank_0( + f"\033[93mResuming mid-epoch: epoch={args.current_epoch}, " + f"last_step={args.last_step}\033[0m", + to_print=True, + ) + else: + # Epoch-boundary checkpoint: start at the next epoch. + args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 + def freeze_router_params(model: Model): """