Skip to content

Commit 2c8593c

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Use abstract types for return type hints in pathwaysutils.
This change replaces concrete types like `dict`, `set`, and `list` with their abstract counterparts `Mapping`, `Set`, and `Sequence` in function signatures and class attributes across `pathwaysutils`. This improves type hint flexibility and adheres to Python best practices. PiperOrigin-RevId: 884182804
1 parent 83d0aa3 commit 2c8593c

7 files changed

Lines changed: 33 additions & 29 deletions

File tree

pathwaysutils/elastic/elastic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
import collections
22-
from collections.abc import Mapping, Sequence
22+
from collections.abc import Mapping, Sequence, Set
2323
import logging
2424
import time
2525

@@ -83,7 +83,7 @@ def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
8383

8484
def get_slice_to_devices(
8585
devices: Sequence[jax.Device],
86-
) -> dict[int, Sequence[jax.Device]]:
86+
) -> Mapping[int, Sequence[jax.Device]]:
8787
"""Returns the mapping from slice index to devices."""
8888
slice_to_devices = collections.defaultdict(list)
8989
for d in devices:
@@ -94,7 +94,7 @@ def get_slice_to_devices(
9494
@timing.timeit
9595
def get_active_slice_indices(
9696
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
97-
) -> set[int]:
97+
) -> Set[int]:
9898
"""Returns the set of active slices indices.
9999
100100
Args:
@@ -153,7 +153,7 @@ def wait_for_slices(
153153
poll_interval: float | int = 10,
154154
timeout: float | int | None = None,
155155
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
156-
) -> set[int]:
156+
) -> Set[int]:
157157
"""Waits until after at least `slice_count` slices become active.
158158
159159
Args:

pathwaysutils/elastic/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
events. It also provides a utility for waiting for slices to become active.
1919
"""
2020

21-
from collections.abc import Callable, Mapping, Sequence
21+
from collections.abc import Callable, Mapping, Sequence, Set
2222
import functools
2323
import logging
2424
from typing import Any, TypeVar
@@ -58,7 +58,7 @@ class Manager:
5858

5959
_total_slice_count: int | None = None
6060
slice_to_devices: Mapping[int, Sequence[jax.Device]]
61-
active_slice_indices: set[int]
61+
active_slice_indices: Set[int]
6262

6363
def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
6464
"""Initializes the manager.

pathwaysutils/experimental/profiling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414
"""Experimental profiling utilites."""
1515

16+
from collections.abc import Mapping
1617
from typing import Any
1718

1819
from pathwaysutils import profiling
1920

2021

2122
def start_trace(
22-
profile_request: dict[str, Any],
23+
profile_request: Mapping[str, Any],
2324
*,
2425
create_perfetto_link: bool = False,
2526
create_perfetto_trace: bool = False,
@@ -33,7 +34,7 @@ def start_trace(
3334
Use `jax.profiler.stop_trace` to end profiling.
3435
3536
Args:
36-
profile_request: A dictionary containing the profile request options.
37+
profile_request: A mapping containing the profile request options.
3738
create_perfetto_link: A boolean which, if true, creates and prints link to
3839
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
3940
block until the link is opened and Perfetto loads the trace. This feature

pathwaysutils/experimental/reshard.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ def __init__(
5454
destination_shardings: Sequence[jax.sharding.Sharding],
5555
donate: bool,
5656
):
57+
5758
def ifrt_hlo_sharding(
5859
aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding
59-
) -> dict[str, Any]:
60+
) -> Mapping[str, Any]:
6061
result = {
6162
"devices": {
6263
"device_ids": [
@@ -190,7 +191,9 @@ class NoIntermediateShardingNeededError(NoIntermediateShardingError):
190191
"""Raised when no intermediate sharding is needed for optimization."""
191192

192193

193-
def _get_sharding_spec_dims(sharding: jax.sharding.NamedSharding) -> list[int]:
194+
def _get_sharding_spec_dims(
195+
sharding: jax.sharding.NamedSharding,
196+
) -> Sequence[int]:
194197
"""Gets the sharding dimension sizes from a NamedSharding."""
195198
mesh = sharding.mesh
196199
dims = []
@@ -244,7 +247,7 @@ def _get_split_candidates(
244247
src_dims: Sequence[int],
245248
dst_dims: Sequence[int],
246249
gcd_shards: Sequence[int],
247-
) -> list[tuple[int, str]]:
250+
) -> Sequence[tuple[int, str]]:
248251
"""Finds dimensions that are candidates for splitting."""
249252
split_candidates = []
250253
for i, spec in enumerate(in_sharding.spec):
@@ -271,8 +274,8 @@ def _build_intermediate_mesh_and_spec(
271274
in_spec: jax.sharding.PartitionSpec,
272275
src_dims: Sequence[int],
273276
dst_dims: Sequence[int],
274-
split_candidates: list[tuple[int, str]],
275-
) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, list[str]]:
277+
split_candidates: Sequence[tuple[int, str]],
278+
) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, Sequence[str]]:
276279
"""Builds the intermediate Mesh and PartitionSpec."""
277280
# Build a map of mesh axis to split information: (dim_idx, replicas)
278281
mesh_axis_to_split_info = {}
@@ -321,7 +324,7 @@ def _build_intermediate_mesh_and_spec(
321324

322325
def find_intermediate_sharding(
323326
in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding
324-
) -> tuple[jax.sharding.NamedSharding, list[str]]:
327+
) -> tuple[jax.sharding.NamedSharding, Sequence[str]]:
325328
"""Finds an intermediate sharding to reshard to before target sharding.
326329
327330
This function tries to find an intermediate sharding that can be used to
@@ -343,9 +346,9 @@ def find_intermediate_sharding(
343346
out_sharding: The target sharding.
344347
345348
Returns:
346-
A tuple containing:
347-
- An intermediate sharding.
348-
- A list of axis names that are replicated in the intermediate sharding.
349+
A tuple (intermediate_sharding, replicated_axes), where
350+
replicated_axes is a sequence of axis names that are replicated in the
351+
intermediate sharding.
349352
350353
Raises:
351354
NoIntermediateShardingError: If no intermediate sharding is found.

pathwaysutils/persistence/helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Helper functions for persistence."""
1515

1616
import base64
17-
from collections.abc import Sequence
17+
from collections.abc import Mapping, Sequence
1818
import concurrent.futures
1919
import datetime
2020
import json
@@ -94,7 +94,7 @@ def get_hlo_sharding_string(
9494
def get_shape_info(
9595
dtype: np.dtype,
9696
dimensions: Sequence[int],
97-
) -> dict[str, Sequence[int] | str]:
97+
) -> Mapping[str, Sequence[int] | str]:
9898
"""Returns shape info in the format expected by read requests."""
9999
return {
100100
"xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype),
@@ -108,7 +108,7 @@ def get_write_request(
108108
jax_array: jax.Array,
109109
timeout: datetime.timedelta,
110110
return_dict: bool = False,
111-
) -> str | dict[str, Any]:
111+
) -> str | Mapping[str, Any]:
112112
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
113113
sharding = jax_array.sharding
114114
assert isinstance(sharding, jax.sharding.Sharding), sharding
@@ -172,7 +172,7 @@ def get_read_request(
172172
devices: Sequence[jax.Device],
173173
timeout: datetime.timedelta,
174174
return_dict: bool = False,
175-
) -> str | dict[str, Any]:
175+
) -> str | Mapping[str, Any]:
176176
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
177177
if not isinstance(devices, np.ndarray):
178178
devices = np.array(devices)

pathwaysutils/persistence/orbax_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def serialize(
9696
values: Sequence[jax.Array],
9797
infos: Sequence[ParamInfo],
9898
args: Sequence[SaveArgs] | None = None,
99-
) -> list[future.Future]:
99+
) -> Sequence[future.Future]:
100100
"""Uses Pathways Persistence API to serialize a jax array."""
101101
type_handlers.check_input_arguments(values, infos, args)
102102

@@ -158,7 +158,7 @@ async def deserialize(
158158
self,
159159
infos: Sequence[ParamInfo],
160160
args: Sequence[RestoreArgs] | None = None,
161-
) -> list[jax.Array]:
161+
) -> Sequence[jax.Array]:
162162
"""Uses Pathways Persistence API to deserialize a jax array."""
163163
if args is None:
164164
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")

pathwaysutils/profiling.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
import threading
22-
from typing import Any
22+
from typing import Any, Mapping
2323
import urllib.parse
2424

2525
import fastapi
@@ -59,23 +59,23 @@ def toy_computation() -> None:
5959

6060
def _create_profile_request(
6161
log_dir: os.PathLike[str] | str,
62-
) -> dict[str, Any]:
63-
"""Creates a profile request dictionary from the given options."""
62+
) -> Mapping[str, Any]:
63+
"""Creates a profile request mapping from the given options."""
6464
profile_request = {}
6565
profile_request["traceLocation"] = str(log_dir)
6666

6767
return profile_request
6868

6969

7070
def _start_pathways_trace_from_profile_request(
71-
profile_request: dict[str, Any],
71+
profile_request: Mapping[str, Any],
7272
) -> None:
7373
"""Starts a profiler trace on Pathways components from a profile request.
7474
7575
This will only profile the Pathways components and not the JAX client code.
7676
7777
Args:
78-
profile_request: A dictionary containing the profile request options.
78+
profile_request: A mapping containing the profile request options.
7979
"""
8080
with _profile_state.lock:
8181
global _first_profile_start
@@ -191,7 +191,7 @@ class ProfilingConfig:
191191
repository_path: str
192192

193193
@app.post("/profiling")
194-
async def profiling(pc: ProfilingConfig) -> dict[str, str]:
194+
async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
195195
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
196196
_logger.debug("Writing profiling data to %s", pc.repository_path)
197197
await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)

0 commit comments

Comments
 (0)