Skip to content

Commit c2ea93d

Browse files
author
Orbax Authors
committed
Internal change
PiperOrigin-RevId: 868740843
1 parent ab37bb0 commit c2ea93d

2 files changed

Lines changed: 0 additions & 69 deletions

File tree

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -615,22 +615,6 @@ def _gen_distinct_addressable_indices(
615615
yield from (np_utils.from_hashable_index(i) for i in distinct_indices)
616616

617617

618-
def abstract_fragments(
619-
x: jax.Array | jax.ShapeDtypeStruct | FS,
620-
) -> AbstractFragments:
621-
"""Returns abstract fragments matching the given object."""
622-
if isinstance(x, AbstractFragments):
623-
return x
624-
else:
625-
if isinstance(x, _GenericFragments):
626-
indices = (fragment.index for fragment in x.fragments)
627-
else:
628-
indices = addressable_shards(x)
629-
return AbstractFragments(x.shape, x.dtype, [
630-
AbstractFragment(index=index) for index in indices
631-
])
632-
633-
634618
def validate_fragments_can_be_stacked(fragments: FSconcrete) -> None:
635619
"""Validates that the given fragments can be stacked."""
636620
if not fragments.fragments:

checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,59 +1118,6 @@ def test_unsharded_array_is_fully_replicated(self):
11181118
)
11191119

11201120

1121-
class AbstractFragmentsTest(parameterized.TestCase):
1122-
1123-
def test_returns_abstract_fragments_instance_itself(
1124-
self
1125-
):
1126-
fragments = AbstractFragments(
1127-
shape=(2, 3), dtype=np.dtype(np.float32), fragments=[]
1128-
)
1129-
self.assertIs(fragments, array_fragments.abstract_fragments(fragments))
1130-
1131-
@parameterized.named_parameters(
1132-
('np_fragments', NpFragments),
1133-
('jax_fragments', JaxFragments),
1134-
)
1135-
def test_converts_concrete_fragments(
1136-
self, fragments_t: ConcreteFragmentsT
1137-
):
1138-
fragment_t = fragments_t.FRAGMENT_T
1139-
np_api = fragment_t.NP_API
1140-
concrete_fragments = fragments_t(
1141-
shape=(2, 3),
1142-
dtype=np.dtype(np.float32),
1143-
fragments=[
1144-
fragment_t(index=np.s_[0:2:1, 0:3:1], value=np_api.arange(6)),
1145-
],
1146-
)
1147-
expected_abstract_fragments = AbstractFragments(
1148-
shape=(2, 3),
1149-
dtype=np.dtype(np.float32),
1150-
fragments=[
1151-
AbstractFragment(
1152-
index=np.s_[0:2:1, 0:3:1], value=None
1153-
),
1154-
],
1155-
)
1156-
self.assertEqual(
1157-
expected_abstract_fragments,
1158-
array_fragments.abstract_fragments(concrete_fragments),
1159-
)
1160-
1161-
def test_converts_fully_replicated_shape_dtype_struct(self):
1162-
self.assertEqual(
1163-
AbstractFragments(
1164-
shape=(4, 5),
1165-
dtype=np.dtype(np.float32),
1166-
fragments=[AbstractFragment(index=np.s_[0:4:1, 0:5:1])],
1167-
),
1168-
array_fragments.abstract_fragments(
1169-
jax.ShapeDtypeStruct((4, 5), np.dtype(np.float32))
1170-
),
1171-
)
1172-
1173-
11741121
class StackFragmentsTest(parameterized.TestCase):
11751122

11761123
@parameterized.named_parameters(

0 commit comments

Comments
 (0)