@@ -54,9 +54,10 @@ def __init__(
5454 destination_shardings : Sequence [jax .sharding .Sharding ],
5555 donate : bool ,
5656 ):
57+
5758 def ifrt_hlo_sharding (
5859 aval : jax .core .ShapedArray , sharding : jax .sharding .Sharding
59- ) -> dict [str , Any ]:
60+ ) -> Mapping [str , Any ]:
6061 result = {
6162 "devices" : {
6263 "device_ids" : [
@@ -190,7 +191,9 @@ class NoIntermediateShardingNeededError(NoIntermediateShardingError):
190191 """Raised when no intermediate sharding is needed for optimization."""
191192
192193
193- def _get_sharding_spec_dims (sharding : jax .sharding .NamedSharding ) -> list [int ]:
194+ def _get_sharding_spec_dims (
195+ sharding : jax .sharding .NamedSharding ,
196+ ) -> Sequence [int ]:
194197 """Gets the sharding dimension sizes from a NamedSharding."""
195198 mesh = sharding .mesh
196199 dims = []
@@ -244,7 +247,7 @@ def _get_split_candidates(
244247 src_dims : Sequence [int ],
245248 dst_dims : Sequence [int ],
246249 gcd_shards : Sequence [int ],
247- ) -> list [tuple [int , str ]]:
250+ ) -> Sequence [tuple [int , str ]]:
248251 """Finds dimensions that are candidates for splitting."""
249252 split_candidates = []
250253 for i , spec in enumerate (in_sharding .spec ):
@@ -271,8 +274,8 @@ def _build_intermediate_mesh_and_spec(
271274 in_spec : jax .sharding .PartitionSpec ,
272275 src_dims : Sequence [int ],
273276 dst_dims : Sequence [int ],
274- split_candidates : list [tuple [int , str ]],
275- ) -> tuple [jax .sharding .Mesh , jax .sharding .PartitionSpec , list [str ]]:
277+ split_candidates : Sequence [tuple [int , str ]],
278+ ) -> tuple [jax .sharding .Mesh , jax .sharding .PartitionSpec , Sequence [str ]]:
276279 """Builds the intermediate Mesh and PartitionSpec."""
277280 # Build a map of mesh axis to split information: (dim_idx, replicas)
278281 mesh_axis_to_split_info = {}
@@ -321,7 +324,7 @@ def _build_intermediate_mesh_and_spec(
321324
322325def find_intermediate_sharding (
323326 in_sharding : jax .sharding .Sharding , out_sharding : jax .sharding .Sharding
324- ) -> tuple [jax .sharding .NamedSharding , list [str ]]:
327+ ) -> tuple [jax .sharding .NamedSharding , Sequence [str ]]:
325328 """Finds an intermediate sharding to reshard to before target sharding.
326329
327330 This function tries to find an intermediate sharding that can be used to
@@ -343,9 +346,9 @@ def find_intermediate_sharding(
343346 out_sharding: The target sharding.
344347
345348 Returns:
346- A tuple containing:
347- - An intermediate sharding.
348- - A list of axis names that are replicated in the intermediate sharding.
349+ A tuple (intermediate_sharding, replicated_axes), where
350+ replicated_axes is a sequence of axis names that are replicated in the
351+ intermediate sharding.
349352
350353 Raises:
351354 NoIntermediateShardingError: If no intermediate sharding is found.
0 commit comments