4848from orbax .checkpoint ._src .multihost import multihost
4949from orbax .checkpoint ._src .path import async_path
5050from orbax .checkpoint ._src .path import format_utils
51+ from orbax .checkpoint ._src .path import types as path_types
5152from orbax .checkpoint ._src .serialization import limits
5253from orbax .checkpoint ._src .serialization import ocdbt_utils
5354from orbax .checkpoint ._src .serialization import tensorstore_utils as ts_utils
@@ -312,7 +313,7 @@ def _format_bytes(bytes_value: Optional[int]) -> str:
312313
313314
314315class BasePyTreeCheckpointHandler (
315- async_checkpoint_handler .AsyncCheckpointHandler
316+ async_checkpoint_handler .DeferredPathAsyncCheckpointHandler
316317):
317318 """A CheckpointHandler implementation for any PyTree structure.
318319
@@ -585,7 +586,7 @@ def _handle_diffs(keypath, diff):
585586
586587 async def async_save (
587588 self ,
588- directory : epath .Path ,
589+ directory : epath .Path | path_types . PathAwaitingCreation ,
589590 args : BasePyTreeSaveArgs ,
590591 ) -> Optional [List [future .Future ]]:
591592 """Saves a PyTree to a given directory.
@@ -648,7 +649,7 @@ async def async_save(
648649 use_zarr3 = self ._use_zarr3 ,
649650 )
650651 assert all (
651- leaf .parent_dir == directory for leaf in jax .tree .leaves (param_infos )
652+ leaf .parent_dir is directory for leaf in jax .tree .leaves (param_infos )
652653 )
653654
654655 serialize_ops = [] # List of (coros -> List of futures)
@@ -663,12 +664,16 @@ async def async_save(
663664 # suffix in the checkpoint directory name and if the metadata file exists.
664665 # Cannot rely solely on the metadata file existing pre-empted saves may be
665666 # misclassified as partial saves.
666- partial_save = (
667- await async_path .exists (directory / PYTREE_METADATA_FILE )
668- # TODO: b/428711337 - Use method from v1/_src/partial/path.py instead.
667+ # TODO(b/484298759): Remove the path-based check once all callers use
668+ # partial_save_mode.
669+ is_partial_save_path = (
670+ isinstance (directory , epath .Path )
669671 and '.partial_save' in directory .parent .name
670672 )
671-
673+ is_partial_save = is_partial_save_path
674+ partial_save = is_partial_save and await async_path .exists (
675+ directory / PYTREE_METADATA_FILE
676+ )
672677 batch_requests_ready_time = time .time ()
673678 if partial_save :
674679 serialize_ops , tree_memory_size , param_infos , save_args = (
@@ -1412,6 +1417,8 @@ class BasePyTreeSaveArgs(CheckpointArgs):
14121417 custom_metadata: User-provided custom metadata. An arbitrary
14131418 JSON-serializable dictionary the user can use to store additional
14141419 information. The field is treated as opaque by Orbax.
1420+ partial_save_mode: When True, signals that this save is a partial save
1421+ operation. The handler will merge the new data with existing checkpoint
14151422 """
14161423
14171424 item : PyTree
0 commit comments