Skip to content

Commit c1d3cd1

Browse files
mxberlot authored and Orbax Authors committed
Internal change
PiperOrigin-RevId: 871839606
1 parent cf804e9 commit c1d3cd1

2 files changed

Lines changed: 65 additions & 16 deletions

File tree

checkpoint/orbax/checkpoint/_src/handlers/async_checkpoint_handler.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,69 @@
2020
from etils import epath
2121
from orbax.checkpoint._src.futures import future
2222
from orbax.checkpoint._src.handlers import checkpoint_handler
23+
from orbax.checkpoint._src.path import types as path_types
2324

2425

2526
class AsyncCheckpointHandler(checkpoint_handler.CheckpointHandler):
26-
"""An interface providing async methods that can be used with CheckpointHandler."""
27+
"""An interface providing async methods used with AsyncCheckpointer."""
2728

2829
@abc.abstractmethod
2930
async def async_save(
30-
self, directory: epath.Path, *args, **kwargs
31+
self,
32+
directory: epath.Path,
33+
*args,
34+
**kwargs,
3135
) -> Optional[List[future.Future]]:
32-
"""Constructs a save operation.
36+
"""Saves the given item to the provided directory.
3337
34-
Synchronously awaits a copy of the item, before returning commit futures
35-
necessary to save the item.
38+
Args:
39+
directory: the directory to save to.
40+
*args: additional arguments for save.
41+
**kwargs: additional arguments for save.
42+
43+
Returns:
44+
A list of commit futures which can be awaited to complete the save
45+
operation.
46+
"""
47+
pass
48+
49+
50+
class DeferredPathAsyncCheckpointHandler(AsyncCheckpointHandler):
51+
"""Handler interface that receives Path or PathAwaitingCreation.
3652
37-
Note: Any operations on directory should be done by using
38-
`future.CommitFutureAwaitingContractedSignals` to wait for directories to be
39-
created.
53+
This interface extends AsyncCheckpointHandler with an async_save method that
54+
accepts either an epath.Path or PathAwaitingCreation, allowing handlers to
55+
work with deferred paths (e.g., TFHub) where the actual path is allocated
56+
asynchronously.
57+
58+
Handlers implementing this interface can:
59+
1. Receive a deferred path representation before the path is allocated
60+
2. Wait for STEP_DIRECTORY_CREATION signal inside their CommitFuture
61+
3. Access the path via await_creation() or .path after the signal
62+
"""
63+
64+
@abc.abstractmethod
65+
async def async_save(
66+
self,
67+
directory: epath.Path | path_types.PathAwaitingCreation,
68+
*args,
69+
**kwargs,
70+
) -> Optional[List[future.Future]]:
71+
"""Constructs a save operation with support for deferred paths.
72+
73+
This method accepts an epath.Path or PathAwaitingCreation.
74+
When a deferred path is passed, handler coroutines should wait for the
75+
STEP_DIRECTORY_CREATION signal before accessing the path.
4076
4177
Args:
42-
directory: the directory to save to.
78+
directory: The directory to save to. May be an epath.Path or
79+
PathAwaitingCreation. For deferred paths, await_creation() or signal
80+
ordering ensures the path is available.
4381
*args: additional arguments for save.
4482
**kwargs: additional arguments for save.
83+
84+
Returns:
85+
A list of futures that will commit the data when awaited.
4586
"""
87+
4688
pass

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from orbax.checkpoint._src.multihost import multihost
4949
from orbax.checkpoint._src.path import async_path
5050
from orbax.checkpoint._src.path import format_utils
51+
from orbax.checkpoint._src.path import types as path_types
5152
from orbax.checkpoint._src.serialization import limits
5253
from orbax.checkpoint._src.serialization import ocdbt_utils
5354
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
@@ -312,7 +313,7 @@ def _format_bytes(bytes_value: Optional[int]) -> str:
312313

313314

314315
class BasePyTreeCheckpointHandler(
315-
async_checkpoint_handler.AsyncCheckpointHandler
316+
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
316317
):
317318
"""A CheckpointHandler implementation for any PyTree structure.
318319
@@ -585,7 +586,7 @@ def _handle_diffs(keypath, diff):
585586

586587
async def async_save(
587588
self,
588-
directory: epath.Path,
589+
directory: epath.Path | path_types.PathAwaitingCreation,
589590
args: BasePyTreeSaveArgs,
590591
) -> Optional[List[future.Future]]:
591592
"""Saves a PyTree to a given directory.
@@ -648,7 +649,7 @@ async def async_save(
648649
use_zarr3=self._use_zarr3,
649650
)
650651
assert all(
651-
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
652+
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
652653
)
653654

654655
serialize_ops = [] # List of (coros -> List of futures)
@@ -663,12 +664,16 @@ async def async_save(
663664
# suffix in the checkpoint directory name and if the metadata file exists.
664665
# Cannot rely solely on the metadata file existing, since pre-empted saves may be
665666
# misclassified as partial saves.
666-
partial_save = (
667-
await async_path.exists(directory / PYTREE_METADATA_FILE)
668-
# TODO: b/428711337 - Use method from v1/_src/partial/path.py instead.
667+
# TODO(b/484298759): Remove the path-based check once all callers use
668+
# partial_save_mode.
669+
is_partial_save_path = (
670+
isinstance(directory, epath.Path)
669671
and '.partial_save' in directory.parent.name
670672
)
671-
673+
is_partial_save = is_partial_save_path
674+
partial_save = is_partial_save and await async_path.exists(
675+
directory / PYTREE_METADATA_FILE
676+
)
672677
batch_requests_ready_time = time.time()
673678
if partial_save:
674679
serialize_ops, tree_memory_size, param_infos, save_args = (
@@ -1412,6 +1417,8 @@ class BasePyTreeSaveArgs(CheckpointArgs):
14121417
custom_metadata: User-provided custom metadata. An arbitrary
14131418
JSON-serializable dictionary the user can use to store additional
14141419
information. The field is treated as opaque by Orbax.
1420+
partial_save_mode: When True, signals that this save is a partial save
1421+
operation. The handler will merge the new data with the existing checkpoint.
14151422
"""
14161423

14171424
item: PyTree

0 commit comments

Comments
 (0)