diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md
index 9fa4e7192a..56004c2466 100644
--- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md
+++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md
@@ -29,6 +29,8 @@ The system follows a specific order when deciding which checkpoint to load at st
| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
+| `checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` | `string` | `""` |
+| `checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. For example: `"gs://my-bucket/checkpoints-to-delete"` | `string` | `""` |
| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py
index 220ff6f16d..8fef245a51 100644
--- a/src/maxtext/common/checkpointing.py
+++ b/src/maxtext/common/checkpointing.py
@@ -221,6 +221,8 @@ def create_orbax_checkpoint_manager(
enable_single_controller: bool = False,
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
+ todelete_subdir: str | None = None,
+ todelete_full_path: str | None = None,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
@@ -268,6 +270,8 @@ def create_orbax_checkpoint_manager(
save_decision_policy=save_decision_policy,
preservation_policy=preservation_policy,
async_options=async_options,
+ todelete_subdir=todelete_subdir,
+ todelete_full_path=todelete_full_path,
),
logger=orbax_logger,
)
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index eea7b9c2d0..bf7e9dc8e1 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -59,6 +59,10 @@ max_num_checkpoints_to_keep: None
enable_continuous_checkpointing: False
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False
+# Subdirectory to move checkpoints to before deletion. For example: ".todelete"
+checkpoint_todelete_subdir: ""
+# Full path to move checkpoints to before deletion. For example: "gs://my-bucket/checkpoints-to-delete"
+checkpoint_todelete_full_path: ""
force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index cdb61a25fc..69b7169485 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -309,6 +309,8 @@ class Checkpointing(BaseModel):
enable_single_replica_ckpt_restoring: bool = Field(
False, description="One replica reads and broadcasts the checkpoint."
)
+ checkpoint_todelete_subdir: str = Field("", description="Subdirectory to move checkpoints to before deletion.")
+ checkpoint_todelete_full_path: str = Field("", description="Full path to move checkpoints to before deletion.")
force_unroll: bool = Field(
False,
description="During param-only checkpoint generation, whether to unroll the loop.",