diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 9fa4e7192a..56004c2466 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -29,6 +29,8 @@ The system follows a specific order when deciding which checkpoint to load at st | `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | | `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | | `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | +| `checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` | `string` | `""` (disabled) | +| `checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` (disabled) | | `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | | `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | | `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 220ff6f16d..8fef245a51 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -221,6 +221,8 @@ def create_orbax_checkpoint_manager( enable_single_controller: bool = False, colocated_python_checkpointing: bool = False, enable_single_replica_ckpt_restoring: bool = False, + todelete_subdir: str | None = None, + todelete_full_path: str | None = None, ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -268,6 +270,8 @@ def create_orbax_checkpoint_manager( save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, async_options=async_options, + todelete_subdir=todelete_subdir, + todelete_full_path=todelete_full_path, ), logger=orbax_logger, ) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index eea7b9c2d0..bf7e9dc8e1 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -59,6 +59,10 @@ max_num_checkpoints_to_keep: None enable_continuous_checkpointing: False # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False +# Subdirectory to move checkpoints to before deletion. For example: ".todelete" +checkpoint_todelete_subdir: "" +# Full path to move checkpoints to before deletion. +checkpoint_todelete_full_path: "" force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? 
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index cdb61a25fc..69b7169485 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -309,6 +309,8 @@ class Checkpointing(BaseModel): enable_single_replica_ckpt_restoring: bool = Field( False, description="One replica reads and broadcasts the checkpoint." ) + checkpoint_todelete_subdir: str = Field("", description="Subdirectory to move checkpoints to before deletion.") + checkpoint_todelete_full_path: str = Field("", description="Full path to move checkpoints to before deletion.") force_unroll: bool = Field( False, description="During param-only checkpoint generation, whether to unroll the loop.",