-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathelastic.py
More file actions
255 lines (201 loc) · 7.56 KB
/
elastic.py
File metadata and controls
255 lines (201 loc) · 7.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Elasticity manager.
This class provides a utility for elastic training. It provides a decorator that
retries a function in case of `jax.errors.JaxRuntimeError` caused by slice down
events. It also provides a utility for waiting for slices to become active.
"""
import collections
from collections.abc import Mapping, Sequence, Set
import logging
import time
import jax
import numpy as np
from pathwaysutils.debug import timing
_logger = logging.getLogger(__name__)
_SIMPLE_EXECUTION_TEST_VALUE = 100
_ELASTIC_DOWN_ERROR_TYPES = frozenset([
"DATA_LOSS",
])
_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = frozenset([
"DEADLINE_EXCEEDED",
"NOT_FOUND",
"INTERNAL",
])
def _plus_one(x: jax.Array) -> jax.Array:
"""Adds one to each element in the array.
Used to test if a slice is active.
Args:
x: The array to add one to.
Returns:
The array with one added to each element.
"""
return x + 1
def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
"""Simple execution to test if a slice is active.
This function is used to test if a slice is active. It executes a simple
computation on the devices and returns the result. If any of the devices are
not active, the returned array will fail with a JaxRuntimeError used.
Simply executing this function is not enough to determine if the slice is
active. We also need to check the value of the returned array.
Args:
devices: The devices to execute on.
Returns:
The result of the execution.
"""
if not devices:
raise ValueError("No devices")
test_input = np.zeros(len(devices), dtype=float) + (
_SIMPLE_EXECUTION_TEST_VALUE - 1
)
return jax.pmap(_plus_one, devices=devices)(test_input)
def get_slice_to_devices(
devices: Sequence[jax.Device],
) -> Mapping[int, Sequence[jax.Device]]:
"""Returns the mapping from slice index to devices."""
slice_to_devices = collections.defaultdict(list)
for d in devices:
slice_to_devices[d.slice_index].append(d)
return dict(slice_to_devices)
@timing.timeit
def get_active_slice_indices(
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> Set[int]:
"""Returns the set of active slices indices.
Args:
slice_to_devices: A mapping from slice index to devices. If None,
`get_slice_to_devices(jax.devices())` is used to gather all available
devices and group them by slice.
Returns:
A set of integers representing the indices of the active slices.
"""
if slice_to_devices is None:
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))
active_slice_indices = set()
results = {
slice_index: _simple_execution(devices)
for slice_index, devices in slice_to_devices.items()
}
for slice_index, x in results.items():
_logger.info("Checking slice_index=%s", slice_index)
expected = (
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
+ _SIMPLE_EXECUTION_TEST_VALUE
)
try:
with timing.Timer(f"Checking {slice_index=}"):
jax.block_until_ready(x)
if np.allclose(x, expected):
active_slice_indices.add(slice_index)
_logger.info("slice_index=%s active", slice_index)
else:
_logger.error(
"Error with _simple_execution for slice_index=%s. "
"This should never happen. Expected: %r, Actual: %r",
slice_index,
expected,
x,
)
raise ValueError(
f"Error with _simple_execution for slice_index={slice_index}."
)
except jax.errors.JaxRuntimeError as error:
if not is_error_due_to_slice_down(error):
raise
_logger.info("slice_index=%s bad", slice_index)
_logger.info("active_slice_indices=%s", active_slice_indices)
return active_slice_indices
def wait_for_slices(
slice_count: int,
poll_interval: float | int = 10,
timeout: float | int | None = None,
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> Set[int]:
"""Waits until after at least `slice_count` slices become active.
Args:
slice_count: The number of slices to wait for.
poll_interval: The minimum number of seconds to wait between availability
checks. If the check takes longer than this, the next check will start
immediately after the current check completes. Defaults to 10 seconds.
timeout: The maximum number of seconds to wait. If None, there is no
timeout.
slice_to_devices: A mapping from slice index to devices. If None,
`get_slice_to_devices(jax.devices())` is used.
Returns:
The active slice indices
Raises:
TimeoutError: If the timeout is reached before the slices become
active.
"""
if slice_to_devices is None:
slice_to_devices = get_slice_to_devices(jax.devices())
start_time = time.time()
while True:
check_start_time = time.time()
active_slice_indices = get_active_slice_indices(slice_to_devices)
if len(active_slice_indices) >= slice_count:
_logger.info("%s slices active.", len(active_slice_indices))
return active_slice_indices
_logger.info(
"%s slices active. Wanting at least %s.",
len(active_slice_indices),
slice_count,
)
time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
if timeout is not None:
elapsed_time = time.time() - start_time
if elapsed_time + time_to_sleep >= timeout:
raise TimeoutError(
f"Timed out waiting for {slice_count} slices. Only"
f" {len(active_slice_indices)} active after"
f" {elapsed_time:.2f} seconds."
f" Next check would occur after the timeout of {timeout}"
" seconds."
)
if time_to_sleep > 0:
_logger.info("Sleeping for %.2f seconds.", time_to_sleep)
time.sleep(time_to_sleep)
def is_error_due_to_slice_down(error: Exception) -> bool:
"""Returns True if the error is due to slice down.
The error types that are considered due to slice down are
jax.errors.JaxRuntimeError with the following error kind in the message:
- DATA_LOSS
- DEADLINE_EXCEEDED
- NOT_FOUND
- INTERNAL
Args:
error: The error to check.
"""
error_due_to_slice_down = False
traceback_logging_level = logging.DEBUG
if isinstance(error, jax.errors.JaxRuntimeError):
if any(
error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES
):
_logger.info("Caught an error due to slice down")
error_due_to_slice_down = True
elif any(
error_type in str(error)
for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
):
_logger.warning(
"Caught an error due that may or may not be due to slice down. This"
" error will be treated as due to slice down."
)
traceback_logging_level = logging.WARNING
error_due_to_slice_down = True
if not error_due_to_slice_down:
_logger.info("Caught an error not due to slice down")
_logger.log(traceback_logging_level, "Error details:", exc_info=True)
return error_due_to_slice_down