|
18 | 18 |
|
19 | 19 | import asyncio |
20 | 20 | import copy |
| 21 | +import dataclasses |
21 | 22 | import sys |
22 | 23 | from typing import Any, Dict, Optional, Sequence, Tuple, TypeAlias, Union |
23 | 24 |
|
24 | 25 | from absl import logging |
| 26 | +from etils import epath |
25 | 27 | import jax |
26 | 28 | import numpy as np |
27 | 29 | from orbax.checkpoint._src.futures import future |
28 | 30 | from orbax.checkpoint._src.metadata import value as value_metadata |
29 | 31 | from orbax.checkpoint._src.multihost import multihost |
30 | 32 | from orbax.checkpoint._src.path import format_utils |
| 33 | +from orbax.checkpoint._src.path import types as path_types |
31 | 34 | from orbax.checkpoint._src.serialization import jax_array_handlers |
32 | 35 | from orbax.checkpoint._src.serialization import ocdbt_utils |
33 | 36 | from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils |
@@ -85,10 +88,16 @@ async def metadata( |
85 | 88 | ) -> Sequence[value_metadata.ArrayMetadata]: |
86 | 89 | open_ops = [] |
87 | 90 | for info in infos: |
| 91 | + resolved_parent_dir = await path_types.await_and_resolve_as_posix( |
| 92 | + info.parent_dir |
| 93 | + ) |
| 94 | + resolved_info = dataclasses.replace( |
| 95 | + info, parent_dir=epath.Path(resolved_parent_dir) |
| 96 | + ) |
88 | 97 | # Use OCDBT flag from the existing checkpoint. |
89 | 98 | use_ocdbt = info.is_ocdbt_checkpoint |
90 | 99 | array_read_spec = ts_utils.build_array_read_spec( |
91 | | - info, |
| 100 | + resolved_info, |
92 | 101 | use_ocdbt=use_ocdbt, |
93 | 102 | metadata_key=self._metadata_key, |
94 | 103 | raise_array_data_missing_error=info.raise_array_data_missing_error, |
@@ -122,8 +131,14 @@ async def _background_serialize( |
122 | 131 | """Serializes numpy arrays in a background thread.""" |
123 | 132 | write_coros = [] |
124 | 133 | for value, info, arg in zip(values, infos, args): |
| 134 | + resolved_parent_dir = await path_types.await_and_resolve_as_posix( |
| 135 | + info.parent_dir |
| 136 | + ) |
| 137 | + resolved_info = dataclasses.replace( |
| 138 | + info, parent_dir=epath.Path(resolved_parent_dir) |
| 139 | + ) |
125 | 140 | array_write_spec = ts_utils.build_array_write_spec( |
126 | | - info=info, |
| 141 | + info=resolved_info, |
127 | 142 | arg=arg, |
128 | 143 | global_shape=value.shape, |
129 | 144 | local_shape=value.shape, |
@@ -175,14 +190,22 @@ async def deserialize( |
175 | 190 | types.check_input_arguments(infos, args) |
176 | 191 | open_futures = [] |
177 | 192 | for info, arg in zip(infos, args): |
| 193 | + resolved_parent_dir = await path_types.await_and_resolve_as_posix( |
| 194 | + info.parent_dir |
| 195 | + ) |
| 196 | + resolved_info = dataclasses.replace( |
| 197 | + info, parent_dir=epath.Path(resolved_parent_dir) |
| 198 | + ) |
178 | 199 | if not info.is_ocdbt_checkpoint: |
179 | 200 | await ts_utils.assert_parameter_files_exist( |
180 | | - info.parent_dir / info.name, self._metadata_key, info.use_zarr3 |
| 201 | + resolved_info.parent_dir / resolved_info.name, |
| 202 | + self._metadata_key, |
| 203 | + resolved_info.use_zarr3, |
181 | 204 | ) |
182 | 205 | # Use OCDBT flag from the existing checkpoint. |
183 | 206 | use_ocdbt = info.is_ocdbt_checkpoint |
184 | 207 | array_read_spec = ts_utils.build_array_read_spec( |
185 | | - info, |
| 208 | + resolved_info, |
186 | 209 | use_ocdbt=use_ocdbt, |
187 | 210 | metadata_key=self._metadata_key, |
188 | 211 | raise_array_data_missing_error=info.raise_array_data_missing_error, |
@@ -293,6 +316,7 @@ def _get_json_tspec( |
293 | 316 | """Gets Tensorstore spec in JSON format.""" |
294 | 317 | if info.parent_dir is None: |
295 | 318 | raise ValueError('Must provide info.parent_dir.') |
| 319 | + assert isinstance(info.parent_dir, epath.Path) |
296 | 320 | directory = (info.parent_dir / self._filename).as_posix() |
297 | 321 | kvstore_tspec = ts_utils.build_kvstore_tspec(directory, use_ocdbt=False) |
298 | 322 | tspec = { |
@@ -330,7 +354,13 @@ async def _background_serialize( |
330 | 354 | info, |
331 | 355 | value, |
332 | 356 | ) in zip(infos, values): |
333 | | - tspec = self._get_json_tspec(info) |
| 357 | + resolved_parent_dir = await path_types.await_and_resolve_as_posix( |
| 358 | + info.parent_dir |
| 359 | + ) |
| 360 | + resolved_info = dataclasses.replace( |
| 361 | + info, parent_dir=epath.Path(resolved_parent_dir) |
| 362 | + ) |
| 363 | + tspec = self._get_json_tspec(resolved_info) |
334 | 364 | if multihost.process_index() == 0: |
335 | 365 | t = await ts.open( |
336 | 366 | tspec, |
@@ -368,7 +398,13 @@ async def deserialize( |
368 | 398 | open_futures = [] |
369 | 399 |
|
370 | 400 | for info in infos: |
371 | | - tspec = self._get_json_tspec(info) |
| 401 | + resolved_parent_dir = await path_types.await_and_resolve_as_posix( |
| 402 | + info.parent_dir |
| 403 | + ) |
| 404 | + resolved_info = dataclasses.replace( |
| 405 | + info, parent_dir=epath.Path(resolved_parent_dir) |
| 406 | + ) |
| 407 | + tspec = self._get_json_tspec(resolved_info) |
372 | 408 | open_future = ts.open( |
373 | 409 | tspec, open=True, read=True, context=self._ts_context |
374 | 410 | ) |
|
0 commit comments