Skip to content

Commit 91846a9

Browse files
mxberlot authored and Orbax Authors committed
Internal change
PiperOrigin-RevId: 872295675
1 parent ab37bb0 commit 91846a9

6 files changed

Lines changed: 177 additions & 18 deletions

File tree

checkpoint/orbax/checkpoint/_src/path/types.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,24 @@ async def await_creation(self) -> Path:
6767
The path that was created.
6868
"""
6969
...
70+
71+
72+
async def await_and_resolve_as_posix(
    path: Path | PathAwaitingCreation,
) -> str:
  """Returns the POSIX string for `path`, waiting for creation when needed.

  A plain `Path` is converted right away via `as_posix()`. A
  `PathAwaitingCreation` is first awaited through `await_creation()`, and the
  path that call yields is the one converted to a string.

  Args:
    path: Either a concrete `Path` or a `PathAwaitingCreation`.

  Returns:
    The result of `as_posix()` on the concrete path.
  """
  # NOTE(review): if `PathAwaitingCreation` is a `typing.Protocol`, this
  # `isinstance` check presumably relies on it being `@runtime_checkable` --
  # confirm against its definition.
  if not isinstance(path, PathAwaitingCreation):
    return path.as_posix()
  created = await path.await_creation()
  return created.as_posix()
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
17+
from absl.testing import absltest
18+
from etils import epath
19+
from orbax.checkpoint._src.path import types as path_types
20+
from orbax.checkpoint.google.path import tfhub_atomicity
21+
22+
23+
class AwaitAndResolveAsPosixTest(absltest.TestCase):
  """Tests for `path_types.await_and_resolve_as_posix`.

  Each test wraps its coroutine in `asyncio.run`, since `absltest.TestCase`
  test methods are synchronous.
  """

  def test_resolves_epath_immediately(self):
    """A plain `epath.Path` resolves to its POSIX string."""

    async def scenario():
      posix = await path_types.await_and_resolve_as_posix(
          epath.Path('/tmp/test/dir')
      )
      self.assertEqual(posix, '/tmp/test/dir')

    asyncio.run(scenario())

  def test_resolves_deferred_path_after_set(self):
    """A `DeferredPath` whose path is already set resolves to that path."""

    async def scenario():
      tmp_dir = self.create_tempdir().full_path
      pending = tfhub_atomicity.DeferredPath()
      pending.set_path(epath.Path(tmp_dir))
      posix = await path_types.await_and_resolve_as_posix(pending)
      self.assertEqual(posix, tmp_dir)

    asyncio.run(scenario())

  def test_blocks_until_deferred_path_set(self):
    """Resolution stays pending until `set_path` is called."""

    async def scenario():
      tmp_dir = self.create_tempdir().full_path
      pending = tfhub_atomicity.DeferredPath()
      resolution = asyncio.create_task(
          path_types.await_and_resolve_as_posix(pending)
      )

      # Yield to the event loop so the task gets a chance to run; it must
      # still be unfinished because no path has been set yet.
      await asyncio.sleep(0.1)
      self.assertFalse(resolution.done())

      pending.set_path(epath.Path(tmp_dir))
      posix = await resolution
      self.assertEqual(posix, tmp_dir)

    asyncio.run(scenario())

  def test_resolves_child_deferred_path(self):
    """A child built with `/` resolves once the parent's path is set."""

    async def scenario():
      tmp_dir = self.create_tempdir().full_path
      pending = tfhub_atomicity.DeferredPath()
      pending_child = pending / 'subdir'
      resolution = asyncio.create_task(
          path_types.await_and_resolve_as_posix(pending_child)
      )

      # The child must not resolve before the parent has been given a path.
      await asyncio.sleep(0.1)
      self.assertFalse(resolution.done())

      pending.set_path(epath.Path(tmp_dir))
      posix = await resolution
      self.assertEqual(posix, tmp_dir + '/subdir')

    asyncio.run(scenario())
75+
76+
77+
if __name__ == '__main__':
78+
absltest.main()

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import warnings
2525

2626
from absl import logging
27+
from etils import epath
2728
import humanize
2829
import jax
2930
import jax.numpy as jnp
@@ -38,6 +39,7 @@
3839
from orbax.checkpoint._src.multihost import multihost
3940
from orbax.checkpoint._src.multihost import multislice
4041
from orbax.checkpoint._src.path import async_path
42+
from orbax.checkpoint._src.path import types as path_types
4143
from orbax.checkpoint._src.path import utils as path_utils
4244
from orbax.checkpoint._src.serialization import jax_array_restore_args
4345
from orbax.checkpoint._src.serialization import limits
@@ -157,9 +159,8 @@ async def _async_serialize_shardings(
157159
continue
158160
if info.parent_dir is None:
159161
raise ValueError('parent_dir cannot be None')
160-
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
161-
info.parent_dir.as_posix(), info.name
162-
)
162+
dir_str = await path_types.await_and_resolve_as_posix(info.parent_dir)
163+
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(dir_str, info.name)
163164
if multihost.is_primary_host(primary_host):
164165
# OCDBT is not used for sharding metadata.
165166
sharding_ts_context = info.ts_context
@@ -568,8 +569,14 @@ async def _async_serialize_replica_slices(
568569
'Replica_separate_folder is disabled as OCDBT is not enabled.',
569570
1,
570571
)
572+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
573+
info.parent_dir
574+
)
575+
resolved_info = dataclasses.replace(
576+
info, parent_dir=epath.Path(resolved_parent_dir)
577+
)
571578
array_write_spec = ts_utils.build_array_write_spec(
572-
info=info,
579+
info=resolved_info,
573580
arg=arg,
574581
global_shape=value.global_shape,
575582
local_shape=value.local_shape,
@@ -718,8 +725,9 @@ async def _deserialize_shardings(
718725
)
719726
assert info.parent_dir is not None
720727
if info.name:
728+
dir_str = await path_types.await_and_resolve_as_posix(info.parent_dir)
721729
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
722-
info.parent_dir.as_posix(), info.name
730+
dir_str, info.name
723731
)
724732
t = await ts.open(
725733
tspec_sharding,
@@ -765,8 +773,14 @@ async def _async_deserialize(
765773
await _validate_non_ocdbt_files(infos, metadata_key)
766774
deserialize_ops = []
767775
for info, arg, sharding in zip(infos, args, shardings):
776+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
777+
info.parent_dir
778+
)
779+
resolved_info = dataclasses.replace(
780+
info, parent_dir=epath.Path(resolved_parent_dir)
781+
)
768782
array_read_spec = ts_utils.build_array_read_spec(
769-
info,
783+
resolved_info,
770784
use_ocdbt=use_ocdbt,
771785
metadata_key=metadata_key,
772786
raise_array_data_missing_error=info.raise_array_data_missing_error,
@@ -1010,8 +1024,14 @@ async def metadata(
10101024
for info in infos:
10111025
# Use OCDBT flag from the existing checkpoint.
10121026
use_ocdbt = info.is_ocdbt_checkpoint
1027+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
1028+
info.parent_dir
1029+
)
1030+
resolved_info = dataclasses.replace(
1031+
info, parent_dir=epath.Path(resolved_parent_dir)
1032+
)
10131033
array_read_spec = ts_utils.build_array_read_spec(
1014-
info,
1034+
resolved_info,
10151035
use_ocdbt=use_ocdbt,
10161036
metadata_key=self._metadata_key,
10171037
raise_array_data_missing_error=info.raise_array_data_missing_error,
@@ -1024,8 +1044,9 @@ async def metadata(
10241044
assert info.parent_dir is not None
10251045
sharding_op = None
10261046
if info.name:
1047+
dir_str = await path_types.await_and_resolve_as_posix(info.parent_dir)
10271048
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
1028-
info.parent_dir.as_posix(), info.name
1049+
dir_str, info.name
10291050
)
10301051
if sharding_file_exists:
10311052
sharding_op = ts.open(

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def _get_json_tspec(
648648
if info.name is None or info.parent_dir is None:
649649
raise ValueError('Must provide info.name and info.parent_dir.')
650650
parent_dir = info.parent_dir
651-
assert parent_dir is not None
651+
assert isinstance(parent_dir, epath.Path)
652652
directory = parent_dir.as_posix()
653653
kvstore_tspec = build_kvstore_tspec(
654654
directory,
@@ -750,6 +750,7 @@ def build_array_read_spec(
750750
"""Gets ArrayReadSpec for reading."""
751751
if info.name is None or info.parent_dir is None:
752752
raise ValueError('Must provide info.name and info.parent_dir.')
753+
assert isinstance(info.parent_dir, epath.Path)
753754
return ArrayReadSpec(
754755
directory=info.parent_dir.as_posix(),
755756
relative_array_filename=info.name,
@@ -778,7 +779,7 @@ def build_array_write_spec(
778779
if info.name is None or info.parent_dir is None:
779780
raise ValueError('Must provide info.name and info.parent_dir.')
780781
parent_dir = info.parent_dir
781-
assert parent_dir is not None
782+
assert isinstance(parent_dir, epath.Path)
782783
directory = parent_dir.as_posix()
783784

784785
return ArrayWriteSpec(

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@
1818

1919
import asyncio
2020
import copy
21+
import dataclasses
2122
import sys
2223
from typing import Any, Dict, Optional, Sequence, Tuple, TypeAlias, Union
2324

2425
from absl import logging
26+
from etils import epath
2527
import jax
2628
import numpy as np
2729
from orbax.checkpoint._src.futures import future
2830
from orbax.checkpoint._src.metadata import value as value_metadata
2931
from orbax.checkpoint._src.multihost import multihost
3032
from orbax.checkpoint._src.path import format_utils
33+
from orbax.checkpoint._src.path import types as path_types
3134
from orbax.checkpoint._src.serialization import jax_array_handlers
3235
from orbax.checkpoint._src.serialization import ocdbt_utils
3336
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
@@ -85,10 +88,16 @@ async def metadata(
8588
) -> Sequence[value_metadata.ArrayMetadata]:
8689
open_ops = []
8790
for info in infos:
91+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
92+
info.parent_dir
93+
)
94+
resolved_info = dataclasses.replace(
95+
info, parent_dir=epath.Path(resolved_parent_dir)
96+
)
8897
# Use OCDBT flag from the existing checkpoint.
8998
use_ocdbt = info.is_ocdbt_checkpoint
9099
array_read_spec = ts_utils.build_array_read_spec(
91-
info,
100+
resolved_info,
92101
use_ocdbt=use_ocdbt,
93102
metadata_key=self._metadata_key,
94103
raise_array_data_missing_error=info.raise_array_data_missing_error,
@@ -122,8 +131,14 @@ async def _background_serialize(
122131
"""Serializes numpy arrays in a background thread."""
123132
write_coros = []
124133
for value, info, arg in zip(values, infos, args):
134+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
135+
info.parent_dir
136+
)
137+
resolved_info = dataclasses.replace(
138+
info, parent_dir=epath.Path(resolved_parent_dir)
139+
)
125140
array_write_spec = ts_utils.build_array_write_spec(
126-
info=info,
141+
info=resolved_info,
127142
arg=arg,
128143
global_shape=value.shape,
129144
local_shape=value.shape,
@@ -175,14 +190,22 @@ async def deserialize(
175190
types.check_input_arguments(infos, args)
176191
open_futures = []
177192
for info, arg in zip(infos, args):
193+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
194+
info.parent_dir
195+
)
196+
resolved_info = dataclasses.replace(
197+
info, parent_dir=epath.Path(resolved_parent_dir)
198+
)
178199
if not info.is_ocdbt_checkpoint:
179200
await ts_utils.assert_parameter_files_exist(
180-
info.parent_dir / info.name, self._metadata_key, info.use_zarr3
201+
resolved_info.parent_dir / resolved_info.name,
202+
self._metadata_key,
203+
resolved_info.use_zarr3,
181204
)
182205
# Use OCDBT flag from the existing checkpoint.
183206
use_ocdbt = info.is_ocdbt_checkpoint
184207
array_read_spec = ts_utils.build_array_read_spec(
185-
info,
208+
resolved_info,
186209
use_ocdbt=use_ocdbt,
187210
metadata_key=self._metadata_key,
188211
raise_array_data_missing_error=info.raise_array_data_missing_error,
@@ -293,6 +316,7 @@ def _get_json_tspec(
293316
"""Gets Tensorstore spec in JSON format."""
294317
if info.parent_dir is None:
295318
raise ValueError('Must provide info.parent_dir.')
319+
assert isinstance(info.parent_dir, epath.Path)
296320
directory = (info.parent_dir / self._filename).as_posix()
297321
kvstore_tspec = ts_utils.build_kvstore_tspec(directory, use_ocdbt=False)
298322
tspec = {
@@ -330,7 +354,13 @@ async def _background_serialize(
330354
info,
331355
value,
332356
) in zip(infos, values):
333-
tspec = self._get_json_tspec(info)
357+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
358+
info.parent_dir
359+
)
360+
resolved_info = dataclasses.replace(
361+
info, parent_dir=epath.Path(resolved_parent_dir)
362+
)
363+
tspec = self._get_json_tspec(resolved_info)
334364
if multihost.process_index() == 0:
335365
t = await ts.open(
336366
tspec,
@@ -368,7 +398,13 @@ async def deserialize(
368398
open_futures = []
369399

370400
for info in infos:
371-
tspec = self._get_json_tspec(info)
401+
resolved_parent_dir = await path_types.await_and_resolve_as_posix(
402+
info.parent_dir
403+
)
404+
resolved_info = dataclasses.replace(
405+
info, parent_dir=epath.Path(resolved_parent_dir)
406+
)
407+
tspec = self._get_json_tspec(resolved_info)
372408
open_future = ts.open(
373409
tspec, open=True, read=True, context=self._ts_context
374410
)

checkpoint/orbax/checkpoint/_src/serialization/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
3232
from orbax.checkpoint._src.metadata import value as value_metadata
3333
from orbax.checkpoint._src.serialization import limits
34+
from orbax.checkpoint.experimental.v1._src.path import types as path_types
3435
import tensorstore as ts
3536

37+
3638
PyTreeMetadataOptions = pytree_metadata_options_lib.PyTreeMetadataOptions
3739

3840

@@ -115,8 +117,8 @@ class ParamInfo:
115117
"""
116118

117119
name: str
118-
parent_dir: epath.Path
119-
path: Optional[epath.Path] = None
120+
parent_dir: epath.Path | path_types.PathAwaitingCreation
121+
path: epath.Path | path_types.PathAwaitingCreation | None = None
120122
keypath: Optional[Tuple[Any, ...]] = None
121123
skip_deserialize: Optional[bool] = None
122124
byte_limiter: Optional[limits.ByteLimiter] = None

0 commit comments

Comments
 (0)