Skip to content

Commit 73803ac

Browse files
lukebaumanncopybara-github
authored andcommitted
Add a replica_resize decorator for fault tolerance in elastic Pathways.
PiperOrigin-RevId: 857246882
1 parent b752db7 commit 73803ac

3 files changed

Lines changed: 489 additions & 352 deletions

File tree

pathwaysutils/elastic/elastic.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Elasticity manager.
15+
16+
This class provides a utility for elastic training. It provides a decorator that
17+
retries a function in case of `jax.errors.JaxRuntimeError` caused by slice down
18+
events. It also provides a utility for waiting for slices to become active.
19+
"""
20+
21+
import collections
22+
from collections.abc import Mapping, Sequence
23+
import logging
24+
import time
25+
26+
import jax
27+
import numpy as np
28+
from pathwaysutils.debug import timing
29+
30+
31+
_logger = logging.getLogger(__name__)
32+
33+
_SIMPLE_EXECUTION_TEST_VALUE = 100
34+
_ELASTIC_DOWN_ERROR_TYPES = frozenset(
35+
"DATA_LOSS",
36+
)
37+
_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = frozenset(
38+
"DEADLINE_EXCEEDED",
39+
"NOT_FOUND",
40+
"INTERNAL",
41+
)
42+
43+
44+
def _plus_one(x: jax.Array) -> jax.Array:
45+
"""Adds one to each element in the array.
46+
47+
Used to test if a slice is active.
48+
49+
Args:
50+
x: The array to add one to.
51+
52+
Returns:
53+
The array with one added to each element.
54+
"""
55+
return x + 1
56+
57+
58+
def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
59+
"""Simple execution to test if a slice is active.
60+
61+
This function is used to test if a slice is active. It executes a simple
62+
computation on the devices and returns the result. If any of the devices are
63+
not active, the returned array will fail with a JaxRuntimeError used.
64+
65+
Simply executing this function is not enough to determine if the slice is
66+
active. We also need to check the value of the returned array.
67+
68+
Args:
69+
devices: The devices to execute on.
70+
71+
Returns:
72+
The result of the execution.
73+
"""
74+
if not devices:
75+
raise ValueError("No devices")
76+
77+
test_input = np.zeros(len(devices), dtype=float) + (
78+
_SIMPLE_EXECUTION_TEST_VALUE - 1
79+
)
80+
81+
return jax.pmap(_plus_one, devices=devices)(test_input)
82+
83+
84+
def get_slice_to_devices(
85+
devices: Sequence[jax.Device],
86+
) -> dict[int, Sequence[jax.Device]]:
87+
"""Returns the mapping from slice index to devices."""
88+
slice_to_devices = collections.defaultdict(list)
89+
for d in devices:
90+
slice_to_devices[d.slice_index].append(d)
91+
return dict(slice_to_devices)
92+
93+
94+
@timing.timeit
95+
def get_active_slice_indices(
96+
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
97+
) -> set[int]:
98+
"""Returns the set of active slices indices.
99+
100+
Args:
101+
slice_to_devices: A mapping from slice index to devices. If None,
102+
`get_slice_to_devices(jax.devices())` is used to gather all available
103+
devices and group them by slice.
104+
105+
Returns:
106+
A set of integers representing the indices of the active slices.
107+
"""
108+
if slice_to_devices is None:
109+
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
110+
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))
111+
112+
_logger.debug(
113+
"Getting active slice indices for slices: %s",
114+
sorted(list(slice_to_devices.keys())),
115+
)
116+
117+
active_slice_indices = set()
118+
119+
results = {
120+
slice_index: _simple_execution(devices)
121+
for slice_index, devices in slice_to_devices.items()
122+
}
123+
124+
for slice_index, x in results.items():
125+
_logger.debug("Checking slice_index=%s", slice_index)
126+
expected = (
127+
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
128+
+ _SIMPLE_EXECUTION_TEST_VALUE
129+
)
130+
try:
131+
with timing.Timer(f"Checking {slice_index=}"):
132+
_logger.debug("Blocking until ready for slice_index=%s", slice_index)
133+
jax.block_until_ready(x)
134+
_logger.debug("Execution finished for slice_index=%s", slice_index)
135+
if np.allclose(x, expected):
136+
active_slice_indices.add(slice_index)
137+
_logger.debug("slice_index=%s active", slice_index)
138+
else:
139+
_logger.error(
140+
"Error with _simple_execution for slice_index=%s. "
141+
"This should never happen. Expected: %r, Actual: %r",
142+
slice_index,
143+
expected,
144+
x,
145+
)
146+
raise ValueError(
147+
f"Error with _simple_execution for slice_index={slice_index}."
148+
)
149+
except jax.errors.JaxRuntimeError as error:
150+
_logger.debug(
151+
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
152+
)
153+
if not is_error_due_to_slice_down(error):
154+
_logger.info("Re-raising error for slice_index=%s", slice_index)
155+
raise
156+
_logger.debug("slice_index=%s bad", slice_index)
157+
158+
_logger.debug("active_slice_indices=%s", active_slice_indices)
159+
160+
return active_slice_indices
161+
162+
163+
def wait_for_slices(
164+
slice_count: int,
165+
poll_interval: float | int = 10,
166+
timeout: float | int | None = None,
167+
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
168+
) -> set[int]:
169+
"""Waits until after at least `slice_count` slices become active.
170+
171+
Args:
172+
slice_count: The number of slices to wait for.
173+
poll_interval: The minimum number of seconds to wait between availability
174+
checks. If the check takes longer than this, the next check will start
175+
immediately after the current check completes. Defaults to 10 seconds.
176+
timeout: The maximum number of seconds to wait. If None, there is no
177+
timeout.
178+
slice_to_devices: A mapping from slice index to devices. If None,
179+
`get_slice_to_devices(jax.devices())` is used.
180+
181+
Returns:
182+
The active slice indices
183+
184+
Raises:
185+
TimeoutError: If the timeout is reached before the slices become
186+
active.
187+
"""
188+
if slice_to_devices is None:
189+
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
190+
slice_to_devices = get_slice_to_devices(jax.devices())
191+
192+
_logger.info(
193+
"Waiting for %s slices. Poll interval: %s, Timeout: %s",
194+
slice_count,
195+
poll_interval,
196+
timeout,
197+
)
198+
start_time = time.time()
199+
200+
while True:
201+
check_start_time = time.time()
202+
203+
_logger.debug("Checking active slices...")
204+
active_slice_indices = get_active_slice_indices(slice_to_devices)
205+
if len(active_slice_indices) >= slice_count:
206+
_logger.info(
207+
"Sufficient slices active: %s >= %s. Active indices: %s",
208+
len(active_slice_indices),
209+
slice_count,
210+
active_slice_indices,
211+
)
212+
return active_slice_indices
213+
214+
_logger.info(
215+
"%s slices active. Wanting at least %s. Active indices: %s",
216+
len(active_slice_indices),
217+
slice_count,
218+
active_slice_indices,
219+
)
220+
221+
time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
222+
223+
if timeout is not None:
224+
elapsed_time = time.time() - start_time
225+
if elapsed_time + time_to_sleep >= timeout:
226+
raise TimeoutError(
227+
f"Timed out waiting for {slice_count} slices. Only"
228+
f" {len(active_slice_indices)} active after"
229+
f" {elapsed_time:.2f} seconds."
230+
f" Next check would occur after the timeout of {timeout}"
231+
" seconds."
232+
)
233+
234+
if time_to_sleep > 0:
235+
_logger.debug("Sleeping for %.2f seconds.", time_to_sleep)
236+
237+
time.sleep(time_to_sleep)
238+
239+
240+
def is_error_due_to_slice_down(error: Exception) -> bool:
241+
"""Returns True if the error is due to slice down.
242+
243+
The error types that are considered due to slice down are
244+
jax.errors.JaxRuntimeError with the following error kind in the message:
245+
- DATA_LOSS
246+
- DEADLINE_EXCEEDED
247+
- NOT_FOUND
248+
- INTERNAL
249+
250+
Args:
251+
error: The error to check.
252+
"""
253+
error_due_to_slice_down = False
254+
traceback_logging_level = logging.DEBUG
255+
256+
if isinstance(error, jax.errors.JaxRuntimeError):
257+
_logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error)
258+
if any(
259+
error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES
260+
):
261+
_logger.debug(
262+
"Caught an error due to slice down (matched"
263+
" _ELASTIC_DOWN_ERROR_TYPES)"
264+
)
265+
266+
error_due_to_slice_down = True
267+
268+
elif any(
269+
error_type in str(error)
270+
for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
271+
):
272+
_logger.warning(
273+
"Caught an error that may or may not be due to slice down (matched"
274+
" _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated"
275+
" as due to slice down."
276+
)
277+
traceback_logging_level = logging.WARNING
278+
279+
error_due_to_slice_down = True
280+
281+
if not error_due_to_slice_down:
282+
_logger.debug("Caught an error not due to slice down")
283+
284+
_logger.log(traceback_logging_level, "Error details:", exc_info=True)
285+
286+
return error_due_to_slice_down

0 commit comments

Comments
 (0)