Skip to content

Commit cf21b0b

Browse files
author
Orbax Authors
committed
Internal change
PiperOrigin-RevId: 868682897
1 parent 9504bfc commit cf21b0b

2 files changed

Lines changed: 340 additions & 12 deletions

File tree

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,28 @@ class methods:
2424
FS is `AbstractFragments` then x may additionally be a
2525
`jax.ShapeDtypeStruct`.
2626
- `none_of(x)`: gives an FS shaped like x, with no fragments.
27+
- `addressable_shards_of(x)`: gives an FS shaped like x, with one fragment for
28+
each distinct addressable shard of x. If FS is `AbstractFragments` then x
29+
may additionally be a `jax.ShapeDtypeStruct`.
30+
- `of(x, *indices)`: gives an FS shaped like x, with one fragment for each
31+
index in `indices`. If FS is concrete then the fragment values will be
32+
slices of x. If FS is `AbstractFragments` then x may additionally be
33+
a `jax.ShapeDtypeStruct`.
34+
- `like(fragments, x)` takes an existing Fragments instance and returns a new
35+
instance of FS, with the same shape and dtype and fragment indices, but
36+
with the fragment values replaced by slices of x. If FS is
37+
`AbstractFragments` then x must be `None` and may be omitted.
38+
39+
Use `FS.of()` and friends to make instances out of arraylike things.
40+
Use `FS.like()` to convert between different kinds of Fragments.
41+
Use `{np, jnp}.asarray(fragments)` to make arraylike things out of (full)
42+
Fragments.
2743
"""
2844
# TODO(b/465196209): Remove when support for Python 3.10 is dropped.
2945
from __future__ import annotations
3046

3147
import dataclasses
32-
from typing import Any, ClassVar, Generic, Literal, Sequence, TypeAlias, TypeVar
48+
from typing import Any, ClassVar, Generator, Generic, Literal, Sequence, TypeAlias, TypeVar
3349

3450
import jax
3551
import numpy as np
@@ -369,6 +385,16 @@ def none_of(cls: type[FS], x: Any) -> FS:
369385
"""Returns a Fragments with no fragments."""
370386
return cls._of(x, indices=[])
371387

388+
@classmethod
389+
def addressable_shards_of(cls: type[FS], x: Any) -> FS:
390+
"""Returns a Fragments exactly spanning the distinct addressable shards of `x`."""
391+
return cls._of(x, indices=_gen_distinct_addressable_indices(x))
392+
393+
@classmethod
394+
def of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
395+
"""Returns a Fragments exactly spanning the given indices."""
396+
return cls._of(x, indices=indices)
397+
372398
def is_degenerate(self) -> bool:
373399
"""Whether this contains only degenerate fragments."""
374400
return all(f.is_degenerate() for f in self.fragments)
@@ -461,11 +487,26 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
461487
fragments = [cls.FRAGMENT_T(index=index) for index in indices]
462488
return cls(x.shape, x.dtype, fragments)
463489

490+
@classmethod
491+
def like(
492+
cls: type[FS],
493+
fragments: _GenericFragments[Any],
494+
value: Literal[None] = None,
495+
) -> FS:
496+
del value
497+
return cls(
498+
shape=fragments.shape,
499+
dtype=fragments.dtype,
500+
fragments=[
501+
cls.FRAGMENT_T(index=f.index) for f in fragments.fragments
502+
],
503+
)
504+
464505

465506
@dataclasses.dataclass(frozen=True, init=False)
466-
class NpFragments(_GenericFragments[NpFragment]):
467-
"""A collection of fragments whose values are of type `np.ndarray`."""
468-
FRAGMENT_T = NpFragment
507+
class _ConcreteFragments(_GenericFragments[Fconcrete]):
508+
"""A collection of concrete fragments."""
509+
FRAGMENT_T: ClassVar[type[Fconcrete]] # The type of fragment values.
469510

470511
@classmethod
471512
def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
@@ -474,19 +515,37 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
474515
fragments = [cls.FRAGMENT_T(index=i, value=x[i]) for i in indices]
475516
return cls(x.shape, x.dtype, fragments)
476517

518+
@classmethod
519+
def like(
520+
cls: type[FS], fragments: _GenericFragments[Any], value: Aconcrete
521+
) -> FS:
522+
_check_fragment_value_type(value, cls.FRAGMENT_T.ARRAY_T)
523+
if fragments.shape != value.shape or fragments.dtype != value.dtype:
524+
raise ValueError(
525+
f'Fragments type {fragments.dtype}[{fragments.shape}] does'
526+
f' not match value type {value.dtype}[{value.shape}].'
527+
)
528+
return cls(
529+
shape=fragments.shape,
530+
dtype=fragments.dtype,
531+
fragments=[
532+
cls.FRAGMENT_T(index=f.index, value=value[f.index])
533+
for f in fragments.fragments
534+
],
535+
)
536+
477537

478538
@dataclasses.dataclass(frozen=True, init=False)
479-
class JaxFragments(_GenericFragments[JaxFragment]):
539+
class NpFragments(_ConcreteFragments[NpFragment]):
540+
"""A collection of fragments whose values are of type `np.ndarray`."""
541+
FRAGMENT_T = NpFragment
542+
543+
544+
@dataclasses.dataclass(frozen=True, init=False)
545+
class JaxFragments(_ConcreteFragments[JaxFragment]):
480546
"""A collection of fragments whose values are of type `jax.Array`."""
481547
FRAGMENT_T = JaxFragment
482548

483-
@classmethod
484-
def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
485-
"""Returns a Fragments with one fragment for each index."""
486-
_check_fragment_value_type(x, cls.FRAGMENT_T.ARRAY_T)
487-
fragments = [cls.FRAGMENT_T(index=i, value=x[i]) for i in indices]
488-
return cls(x.shape, x.dtype, fragments)
489-
490549

491550
# Extra names for backwards compatibility. Most loading and saving code still
492551
# wants to deal with NumPy arrays so that views and operations on them
@@ -534,6 +593,28 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]:
534593
]
535594

536595

596+
def _gen_distinct_addressable_indices(
597+
x: np.ndarray | jax.Array | jax.ShapeDtypeStruct,
598+
) -> Generator[Index, None, None]:
599+
"""Yields fragment indices for distinct addressable shards of x."""
600+
match x:
601+
case jax.Array() | jax.ShapeDtypeStruct():
602+
if not x.sharding:
603+
raise ValueError(
604+
'Cannot determine addressable shards of jax.ShapeDtypeStruct with'
605+
' no sharding.'
606+
)
607+
indices = addressable_shards(x)
608+
case np.ndarray():
609+
indices = (tuple(slice(0, dim, 1) for dim in x.shape),)
610+
case _:
611+
raise TypeError(f'Unsupported type: {type(x)}')
612+
distinct_indices = sorted({
613+
*(np_utils.to_hashable_index(i, shape=x.shape) for i in indices)
614+
})
615+
yield from (np_utils.from_hashable_index(i) for i in distinct_indices)
616+
617+
537618
def abstract_fragments(
538619
x: jax.Array | jax.ShapeDtypeStruct | FS,
539620
) -> AbstractFragments:

0 commit comments

Comments
 (0)