@@ -56,7 +56,7 @@ def __init__(
5656 ):
5757 def ifrt_hlo_sharding (
5858 aval : jax .core .ShapedArray , sharding : jax .sharding .Sharding
59- ) -> dict [str , Any ]:
59+ ) -> Mapping [str , Any ]:
6060 result = {
6161 "devices" : {
6262 "device_ids" : [
@@ -190,7 +190,9 @@ class NoIntermediateShardingNeededError(NoIntermediateShardingError):
190190 """Raised when no intermediate sharding is needed for optimization."""
191191
192192
193- def _get_sharding_spec_dims (sharding : jax .sharding .NamedSharding ) -> list [int ]:
193+ def _get_sharding_spec_dims (
194+ sharding : jax .sharding .NamedSharding ,
195+ ) -> Sequence [int ]:
194196 """Gets the sharding dimension sizes from a NamedSharding."""
195197 mesh = sharding .mesh
196198 dims = []
@@ -244,7 +246,7 @@ def _get_split_candidates(
244246 src_dims : Sequence [int ],
245247 dst_dims : Sequence [int ],
246248 gcd_shards : Sequence [int ],
247- ) -> list [tuple [int , str ]]:
249+ ) -> Sequence [tuple [int , str ]]:
248250 """Finds dimensions that are candidates for splitting."""
249251 split_candidates = []
250252 for i , spec in enumerate (in_sharding .spec ):
@@ -271,8 +273,8 @@ def _build_intermediate_mesh_and_spec(
271273 in_spec : jax .sharding .PartitionSpec ,
272274 src_dims : Sequence [int ],
273275 dst_dims : Sequence [int ],
274- split_candidates : list [tuple [int , str ]],
275- ) -> tuple [jax .sharding .Mesh , jax .sharding .PartitionSpec , list [str ]]:
276+ split_candidates : Sequence [tuple [int , str ]],
277+ ) -> tuple [jax .sharding .Mesh , jax .sharding .PartitionSpec , Sequence [str ]]:
276278 """Builds the intermediate Mesh and PartitionSpec."""
277279 # Build a map of mesh axis to split information: (dim_idx, replicas)
278280 mesh_axis_to_split_info = {}
@@ -321,7 +323,7 @@ def _build_intermediate_mesh_and_spec(
321323
322324def find_intermediate_sharding (
323325 in_sharding : jax .sharding .Sharding , out_sharding : jax .sharding .Sharding
324- ) -> tuple [jax .sharding .NamedSharding , list [str ]]:
326+ ) -> tuple [jax .sharding .NamedSharding , Sequence [str ]]:
325327 """Finds an intermediate sharding to reshard to before target sharding.
326328
327329 This function tries to find an intermediate sharding that can be used to
@@ -345,7 +347,8 @@ def find_intermediate_sharding(
345347 Returns:
346348 A tuple containing:
347349 - An intermediate sharding.
348- - A list of axis names that are replicated in the intermediate sharding.
350+ - A sequence of axis names that are replicated in the intermediate
351+ sharding.
349352
350353 Raises:
351354 NoIntermediateShardingError: If no intermediate sharding is found.
0 commit comments