-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmanager.py
More file actions
221 lines (183 loc) · 7.28 KB
/
manager.py
File metadata and controls
221 lines (183 loc) · 7.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Elasticity manager.

This module provides a utility class for elastic training. It provides a
decorator that retries a function in case of `jax.errors.JaxRuntimeError`
caused by slice down events. It also provides a utility for waiting for slices
to become active.
"""
from collections.abc import Callable, Mapping, Sequence, Set
import functools
import logging
from typing import Any, TypeVar
import jax
from pathwaysutils.elastic import elastic
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
class ElasticRuntimeError(RuntimeError):
  """Raised when elastic training cannot make further progress.

  For example, the retry decorator of `Manager` raises this once every allowed
  attempt has failed because of slice down events.
  """
# Type of the callable wrapped by the retry decorators below. Typing the
# decorators as `Callable[[_F], _F]` keeps the decorated function's signature
# visible to type checkers.
_F = TypeVar("_F", bound=Callable[..., Any])
def _elastic_event_cleanup() -> None:
  """Cleans up JAX state after an elastic (slice down) event.

  Stops any ongoing profiler trace, clears JAX's caches, and deletes every
  live array so that the next attempt starts from a clean state.

  Raises:
    Exception: Any unexpected error (other than `RuntimeError`/`ValueError`)
      raised while stopping the profiler trace is logged and re-raised.
  """
  try:
    _logger.info("Cleaning up any ongoing traces")
    jax.profiler.stop_trace()
  except (RuntimeError, ValueError):
    # Presumably what `stop_trace()` raises when no trace is running (per the
    # log message); treated as "nothing to clean up", not as a failure.
    _logger.info("No ongoing traces to clean up")
  except Exception:
    _logger.exception("Error cleaning up ongoing traces")
    raise

  jax.clear_caches()
  for array in jax.live_arrays():
    array.delete()
class Manager:
  """Utility class for elastic training.

  Records the devices belonging to each slice and which slices are currently
  active. `pause_resume` provides a decorator that retries a function when it
  fails because a slice went down.
  """

  # Cache for `total_slice_count`; filled in on first access.
  _total_slice_count: int | None = None
  # Devices grouped by slice index, as returned by
  # `elastic.get_slice_to_devices`.
  slice_to_devices: Mapping[int, Sequence[jax.Device]]
  # Indices of the slices found to be active at the most recent check
  # (construction, or the slice wait performed before each `pause_resume`
  # attempt).
  active_slice_indices: Set[int]

  def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
    """Initializes the manager.

    Args:
      devices: The devices to use. If None, jax.devices() is used.
    """
    if devices is None:
      devices = jax.devices()
    self.slice_to_devices = elastic.get_slice_to_devices(devices)
    self.active_slice_indices = elastic.get_active_slice_indices(
        slice_to_devices=self.slice_to_devices
    )

  @property
  def total_slice_count(self) -> int:
    """Returns the total number of slices, whether active or not.

    The value is computed from `slice_to_devices` on first access and cached.
    """
    if self._total_slice_count is None:
      self._total_slice_count = len(self.slice_to_devices)
    return self._total_slice_count

  @property
  def default_device(self) -> jax.Device:
    """Returns the device that should be set to the default device.

    This is the first device of one of the slices in `active_slice_indices`.

    Raises:
      ValueError: If there are no active slices.
    """
    try:
      return self.slice_to_devices[next(iter(self.active_slice_indices))][0]
    except StopIteration as error:
      raise ValueError("No active slices") from error

  @property
  def active_slice_count(self) -> int:
    """Returns the number of active slices."""
    return len(self.active_slice_indices)

  def scale_by_active_slices(self, x: int | float) -> int | float:
    """Scales `x` by the fraction of slices that are active.

    Computes `x * active_slice_count / total_slice_count`.

    Args:
      x: The value to scale. An int must scale to an exact integer.

    Returns:
      The scaled value: an int for int input, a float for float input.

    Raises:
      ValueError: If `x` is an int that does not scale to an exact integer,
        or if `x` is neither an int nor a float.
    """
    if isinstance(x, int):
      quotient, remainder = divmod(
          x * self.active_slice_count, self.total_slice_count
      )
      if remainder:
        raise ValueError(
            f"Cannot scale {x=} by active slices because it will result in a "
            f"remainder of {remainder}."
        )
      return quotient
    elif isinstance(x, float):
      return x * self.active_slice_count / self.total_slice_count
    else:
      raise ValueError(f"Unsupported type: {type(x)=}")

  def _elasticity_retry_decorator(
      self,
      max_retries: int,
      pre_callback: Callable[..., Any] | None = None,
      on_elastic_event_callback: Callable[..., Any] | None = None,
  ) -> Callable[[_F], _F]:
    """Retries a function with elasticity fault tolerance.

    Each attempt runs `pre_callback` (if any) and then the wrapped function
    with `default_device` as JAX's default device. A `JaxRuntimeError` that
    `elastic.is_error_due_to_slice_down` attributes to a slice down event
    triggers cleanup of JAX state, `on_elastic_event_callback` (if any), and
    another attempt. Any other exception propagates immediately.

    Args:
      max_retries: The maximum number of times to retry the function.
      pre_callback: A callback to call before each attempt of the wrapped
        function.
      on_elastic_event_callback: A callback to call after an elastic failure
        occurs.

    Returns:
      A function decorator.

    Raises:
      ValueError: If `max_retries` is not positive.
    """
    if max_retries <= 0:
      raise ValueError("max_retries must be positive.")

    def decorator(func: _F) -> _F:
      @functools.wraps(func)
      def wrapper(*args: Any, **kwargs: Any) -> Any:
        for attempt in range(1, max_retries + 1):
          try:
            _logger.info(
                "Elastic attempt %d out of %d", attempt, max_retries
            )
            if pre_callback is not None:
              pre_callback()
            with jax.default_device(self.default_device):
              return func(*args, **kwargs)
          except jax.errors.JaxRuntimeError as error:
            if not elastic.is_error_due_to_slice_down(error):
              raise
            _elastic_event_cleanup()
            if on_elastic_event_callback is not None:
              on_elastic_event_callback()
            if attempt == max_retries:
              # Raise while `error` is still bound so the final slice down
              # failure is preserved as the cause. Earlier errors are not
              # kept, so their frames are not held alive across retries.
              raise ElasticRuntimeError(
                  f"Elastic attempt {max_retries} out of {max_retries} failed."
              ) from error
        # Not reached: `max_retries >= 1`, and the final attempt either
        # returns or raises above.

      return wrapper

    return decorator

  def pause_resume(
      self,
      max_retries: int,
      poll_interval: float | int = 10,
      timeout: float | None = None,
      pre_callback: Callable[..., Any] | None = None,
      on_elastic_event_callback: Callable[..., Any] | None = None,
  ) -> Callable[[_F], _F]:
    """Retries a function with pause/resume fault tolerance.

    This decorator wraps a function to automatically retry execution in case of
    `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
    active slices before each attempt and cleans up JAX caches on failure.

    The function will not be attempted (or reattempted) until all of the slices
    are active.

    Often, the function will dispatch JAX operations and wait for them to
    complete while creating a log message. If using Python logging, it is
    recommended to set `logging.raiseExceptions=True` to ensure that the
    `jax.errors.JaxRuntimeError` is not silently ignored within the logging
    call.

    Args:
      max_retries: The maximum number of times to retry the function.
      poll_interval: The number of seconds to wait between activity checks.
        Defaults to 10 seconds.
      timeout: The maximum number of seconds to wait for slices to become
        active before each retry attempt. If None, there is no timeout.
      pre_callback: A callback to call before each attempt of the function,
        after all slices are active.
      on_elastic_event_callback: A callback to call after an elastic failure
        occurs.

    Returns:
      A decorator. Calling the decorated function returns the wrapped
      function's result from the first successful attempt.

    Raises:
      ValueError: Raised by this call if `max_retries` is not positive.

      The decorated function raises:
        ElasticRuntimeError: If all retry attempts fail. The last slice down
          error is attached as the cause.
        Exception: Any other exception raised by the wrapped function that is
          not due to a slice down event.
    """

    def internal_pre_callback():
      # Block until every slice is active, then refresh the active set so that
      # `default_device` picks a device from a live slice.
      self.active_slice_indices = elastic.wait_for_slices(
          slice_count=self.total_slice_count,
          slice_to_devices=self.slice_to_devices,
          poll_interval=poll_interval,
          timeout=timeout,
      )
      if pre_callback is not None:
        pre_callback()

    return self._elasticity_retry_decorator(
        max_retries=max_retries,
        pre_callback=internal_pre_callback,
        on_elastic_event_callback=on_elastic_event_callback,
    )