@@ -24,12 +24,28 @@ class methods:
2424 FS is `AbstractFragments` then x may additionally be a
2525 `jax.ShapeDtypeStruct`.
2626 - `none_of(x)`: gives an FS shaped like x, with no fragments.
27+ - `addressable_shards_of(x)`: gives an FS shaped like x, with one fragment for
28+ each distinct addressable shard of x. If FS is `AbstractFragments` then x
29+ may additionally be a `jax.ShapeDtypeStruct`.
30+ - `of(x, *indices)`: gives an FS shaped like x, with one fragment for each
31+ index in `indices`. If FS is concrete then the fragment values will be
32+ slices of x. If FS is `AbstractFragments` then x may additionally be
33+ a `jax.ShapeDtypeStruct`.
34+ - `like(fragments, x)` takes an existing Fragments instance and returns a new
35+ instance of FS, with the same shape and dtype and fragment indices, but
36+ with the fragment values replaced by slices of x. If FS is
37+ `AbstractFragments` then x must be `None` and may be omitted.
38+
39+ Use `FS.of()` and friends to make instances out of arraylike things.
40+ Use `FS.like()` to convert between different kinds of Fragments.
41+ Use `{np, jnp}.asarray(fragments)` to make arraylike things out of (full)
42+ Fragments.
2743"""
2844# TODO(b/465196209): Remove when support for Python 3.10 is dropped.
2945from __future__ import annotations
3046
3147import dataclasses
32- from typing import Any , ClassVar , Generic , Literal , Sequence , TypeAlias , TypeVar
48+ from typing import Any , ClassVar , Generator , Generic , Literal , Sequence , TypeAlias , TypeVar
3349
3450import jax
3551import numpy as np
@@ -369,6 +385,16 @@ def none_of(cls: type[FS], x: Any) -> FS:
369385 """Returns a Fragments with no fragments."""
370386 return cls ._of (x , indices = [])
371387
388+ @classmethod
389+ def addressable_shards_of (cls : type [FS ], x : Any ) -> FS :
390+ """Returns a Fragments exactly spanning the distinct addressable shards of `x`."""
391+ return cls ._of (x , indices = _gen_distinct_addressable_indices (x ))
392+
393+ @classmethod
394+ def of (cls : type [FS ], x : Any , * , indices : Sequence [Index ]) -> FS :
395+ """Returns a Fragments exactly spanning the given indices."""
396+ return cls ._of (x , indices = indices )
397+
372398 def is_degenerate (self ) -> bool :
373399 """Whether this contains only degenerate fragments."""
374400 return all (f .is_degenerate () for f in self .fragments )
@@ -461,11 +487,26 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
461487 fragments = [cls .FRAGMENT_T (index = index ) for index in indices ]
462488 return cls (x .shape , x .dtype , fragments )
463489
490+ @classmethod
491+ def like (
492+ cls : type [FS ],
493+ fragments : _GenericFragments [Any ],
494+ value : Literal [None ] = None ,
495+ ) -> FS :
496+ del value
497+ return cls (
498+ shape = fragments .shape ,
499+ dtype = fragments .dtype ,
500+ fragments = [
501+ cls .FRAGMENT_T (index = f .index ) for f in fragments .fragments
502+ ],
503+ )
504+
464505
465506@dataclasses .dataclass (frozen = True , init = False )
466- class NpFragments (_GenericFragments [NpFragment ]):
467- """A collection of fragments whose values are of type `np.ndarray` ."""
468- FRAGMENT_T = NpFragment
507+ class _ConcreteFragments (_GenericFragments [Fconcrete ]):
508+ """A collection of concrete fragments ."""
509+ FRAGMENT_T : ClassVar [ type [ Fconcrete ]] # The type of fragment values.
469510
470511 @classmethod
471512 def _of (cls : type [FS ], x : Any , * , indices : Sequence [Index ]) -> FS :
@@ -474,19 +515,37 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
474515 fragments = [cls .FRAGMENT_T (index = i , value = x [i ]) for i in indices ]
475516 return cls (x .shape , x .dtype , fragments )
476517
518+ @classmethod
519+ def like (
520+ cls : type [FS ], fragments : _GenericFragments [Any ], value : Aconcrete
521+ ) -> FS :
522+ _check_fragment_value_type (value , cls .FRAGMENT_T .ARRAY_T )
523+ if fragments .shape != value .shape or fragments .dtype != value .dtype :
524+ raise ValueError (
525+ f'Fragments type { fragments .dtype } [{ fragments .shape } ] does'
526+ f' not match value type { value .dtype } [{ value .shape } ].'
527+ )
528+ return cls (
529+ shape = fragments .shape ,
530+ dtype = fragments .dtype ,
531+ fragments = [
532+ cls .FRAGMENT_T (index = f .index , value = value [f .index ])
533+ for f in fragments .fragments
534+ ],
535+ )
536+
477537
478538@dataclasses .dataclass (frozen = True , init = False )
479- class JaxFragments (_GenericFragments [JaxFragment ]):
539+ class NpFragments (_ConcreteFragments [NpFragment ]):
540+ """A collection of fragments whose values are of type `np.ndarray`."""
541+ FRAGMENT_T = NpFragment
542+
543+
544+ @dataclasses .dataclass (frozen = True , init = False )
545+ class JaxFragments (_ConcreteFragments [JaxFragment ]):
480546 """A collection of fragments whose values are of type `jax.Array`."""
481547 FRAGMENT_T = JaxFragment
482548
483- @classmethod
484- def _of (cls : type [FS ], x : Any , * , indices : Sequence [Index ]) -> FS :
485- """Returns a Fragments with one fragment for each index."""
486- _check_fragment_value_type (x , cls .FRAGMENT_T .ARRAY_T )
487- fragments = [cls .FRAGMENT_T (index = i , value = x [i ]) for i in indices ]
488- return cls (x .shape , x .dtype , fragments )
489-
490549
491550# Extra names for backwards compatibility. Most loading and saving code still
492551# wants to deal with NumPy arrays so that views and operations on them
@@ -534,6 +593,28 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]:
534593 ]
535594
536595
596+ def _gen_distinct_addressable_indices (
597+ x : np .ndarray | jax .Array | jax .ShapeDtypeStruct ,
598+ ) -> Generator [Index , None , None ]:
599+ """Yields fragment indices for distinct addressable shards of x."""
600+ match x :
601+ case jax .Array () | jax .ShapeDtypeStruct ():
602+ if not x .sharding :
603+ raise ValueError (
604+ 'Cannot determine addressable shards of jax.ShapeDtypeStruct with'
605+ ' no sharding.'
606+ )
607+ indices = addressable_shards (x )
608+ case np .ndarray ():
609+ indices = (tuple (slice (0 , dim , 1 ) for dim in x .shape ),)
610+ case _:
611+ raise TypeError (f'Unsupported type: { type (x )} ' )
612+ distinct_indices = sorted ({
613+ * (np_utils .to_hashable_index (i , shape = x .shape ) for i in indices )
614+ })
615+ yield from (np_utils .from_hashable_index (i ) for i in distinct_indices )
616+
617+
537618def abstract_fragments (
538619 x : jax .Array | jax .ShapeDtypeStruct | FS ,
539620) -> AbstractFragments :
0 commit comments