Skip to content

Commit 1afa2d6

Browse files
Add elastic pause/resume functionality to MaxText.
PiperOrigin-RevId: 890594825
1 parent 2a57a30 commit 1afa2d6

6 files changed

Lines changed: 401 additions & 2 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,3 +1193,9 @@ distill_temperature: 1.0
11931193
# 0.0 value disables this feature.
11941194
distill_beta: 0.0
11951195
distill_layer_indices: None
1196+
1197+
##### Elastic training parameters
1198+
# Elastic training is Pathways-specific and does not work on McJAX.
1199+
elastic_enabled: false
1200+
elastic_timeout_seconds: 300
1201+
elastic_max_retries: 10

src/maxtext/configs/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,26 @@ class Goodput(BaseModel):
15501550
enable_gcp_step_deviation_metrics: bool = Field(True, description="Enable GCP step deviation metrics.")
15511551

15521552

1553+
class ElasticTraining(BaseModel):
  """Configuration for elastic training and fault tolerance.

  Elastic training is Pathways-specific and does not work on McJAX.
  """

  # Master switch for elastic training; cross-field validation elsewhere
  # requires `enable_single_controller=True` (Pathways) when this is set.
  elastic_enabled: bool = Field(False, description="Whether to enable elastic training.")
  # NOTE(review): the description below references `elastic_minimum_slice_count`,
  # but no such field is defined in this change — confirm the field name/source.
  elastic_timeout_seconds: int = Field(
      300,
      description=(
          "The maximum number of seconds to wait for `elastic_minimum_slice_count` slices to become active. If this"
          " timeout is reached during any retry attempt, a `TimeoutError` is raised and training fails."
      ),
  )
  # Upper bound on retry attempts after slice failures or scale-up events.
  elastic_max_retries: int = Field(
      10,
      description="The maximum number of times to retry training when a slice failure occurs or when scaling up.",
  )
1571+
1572+
15531573
class GcpMonitoring(BaseModel):
15541574
"""Configuration for GCP-specific workload monitoring."""
15551575

@@ -1947,6 +1967,7 @@ class MaxTextConfig(
19471967
Checkpointing,
19481968
OrbaxStorage,
19491969
EmergencyCheckpointing,
1970+
ElasticTraining,
19501971
# Data Types and Quantization
19511972
DataTypes,
19521973
Quantization,
@@ -2456,6 +2477,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24562477
# H. RUN ALL CROSS-FIELD VALIDATIONS
24572478
if self.load_parameters_path and self.load_full_state_path:
24582479
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
2480+
if self.elastic_enabled and not self.enable_single_controller:
2481+
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
24592482
if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing:
24602483
raise ValueError("You must set enable_checkpointing=True to load a checkpoint.")
24612484
if self.enable_multi_tier_checkpointing:

src/maxtext/trainers/pre_train/train.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from maxtext.configs import pyconfig
4242
from maxtext.common.common_types import ShardMode
4343
from maxtext.utils.globals import EPS
44+
from maxtext.utils import elastic_utils
4445
# Placeholder: internal
4546

4647
# pylint: disable=too-many-positional-arguments
@@ -678,8 +679,30 @@ def run(config, recorder, diagnostic_config):
678679
def main(argv: Sequence[str]) -> None:
679680
config, recorder, diagnostic_config = initialize(argv)
680681
record_goodput(recorder, RECORD_JOB_START_TIME)
682+
683+
if config.elastic_enabled:
684+
max_logging.log("Elastic utils: Elastic training enabled.")
685+
686+
def elastic_train_func():
687+
"""Train function for elastic training.
688+
689+
Initializes variables and runs the train loop.
690+
"""
691+
elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv)
692+
run(
693+
elastic_config,
694+
elastic_recorder,
695+
elastic_diagnostic_config,
696+
)
697+
698+
train_func = elastic_utils.elastic_retry(config)(elastic_train_func)
699+
else:
700+
# Use the already initialized variables
701+
def train_func():
702+
run(config, recorder, diagnostic_config)
703+
681704
with maybe_monitor_goodput(config):
682-
run(config, recorder, diagnostic_config)
705+
train_func()
683706

684707

685708
if __name__ == "__main__":

src/maxtext/utils/elastic_utils.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for Elastic Training."""
16+
17+
import functools
18+
import jax
19+
from maxtext.utils import gcs_utils
20+
from maxtext.utils import max_logging
21+
import pathwaysutils
22+
from pathwaysutils.elastic import manager
23+
24+
25+
elastic_manager: manager.Manager | None = None
26+
27+
28+
def elastic_enabled(config) -> bool:
  """Reports whether elastic mode is in effect for this run.

  Elastic training requires both the Pathways backend and the
  `elastic_enabled` config flag.
  """
  if not pathwaysutils.is_pathways_backend_used():
    return False
  return config.elastic_enabled
31+
32+
33+
def clean_up_checkpoints(checkpoint_dir: str):
  """Cleans up incomplete checkpoints after an elastic event.

  Inspects the newest numeric step directory under `checkpoint_dir` and
  deletes it when no `commit_success` marker is present, since a missing
  marker means the checkpoint write was interrupted mid-flight.
  """
  max_logging.log("Elastic utils: Checking for incomplete checkpoint after an elastic event...")
  base_dir = gcs_utils.add_trailing_slash(checkpoint_dir)

  # Step directories are named by their integer step number; ignore anything else.
  step_dirs = [name for name in gcs_utils.gcs_list_directories(base_dir) if name.isdigit()]

  if not step_dirs:
    max_logging.log("Found no existing checkpoints. Continuing")
    return

  # The highest step number is the most recent checkpoint.
  newest_step = max(step_dirs, key=int)
  newest_path = f"{base_dir}{newest_step}/"

  max_logging.log(f"Checking latest checkpoint: {newest_path}")

  # A commit_success marker indicates the checkpoint finished writing.
  if gcs_utils.gcs_glob_pattern(f"{newest_path}commit_success*"):
    max_logging.log(f"Found commit_success file. Keeping {newest_path}.")
  else:
    max_logging.log(f"No commit_success file found. Deleting {newest_path}...")
    gcs_utils.gcs_delete_directory(newest_path)
64+
65+
def live_devices():
  """Returns the list of live devices.

  Under Pathways, only devices belonging to currently-active slices are
  returned; without Pathways, every JAX device is considered live.
  """
  global elastic_manager
  if not pathwaysutils.is_pathways_backend_used():
    # No elasticity without Pathways: all devices are live.
    return jax.devices()
  # Lazily construct the module-level manager on first use.
  if elastic_manager is None:
    elastic_manager = manager.Manager()
  active_slices = elastic_manager.active_slice_indices
  return [device for device in jax.devices() if device.slice_index in active_slices]
75+
76+
77+
def chain_callbacks(*funcs):
  """Helper function to chain callbacks.

  Returns a zero-argument callable that invokes each given callback
  once, in the order supplied.
  """

  def run_all():
    for callback in funcs:
      callback()

  return run_all
85+
86+
87+
def elastic_retry(config, callback_fn=None):
  """Decorator for elastic retry.

  Wraps a function so that, when an elastic event occurs, it is retried up
  to `config.elastic_max_retries` times. Before each retry, partial
  checkpoints are removed via `clean_up_checkpoints`; if `callback_fn` is
  given, it runs after that cleanup.

  Args:
    config: Config object.
    callback_fn: Optional callback function invoked after
      `clean_up_checkpoints` on an elastic event.

  Returns:
    A decorator for elastic retry.

  Raises:
    ValueError: If elastic training is not enabled for this run.
  """
  global elastic_manager

  if not elastic_enabled(config):
    raise ValueError(
        "Elastic training requires the Pathways backend, and elastic_enabled"
        " must be set to True: current config.elastic_enabled:"
        f" {config.elastic_enabled}, pathways backend used:"
        f" {pathwaysutils.is_pathways_backend_used()}"
    )

  max_logging.log("Elastic Retry Enabled")
  # Lazily construct the module-level manager on first use.
  if elastic_manager is None:
    elastic_manager = manager.Manager()

  # Always clean partial checkpoints on an elastic event; append the
  # caller's callback when one is supplied.
  cleanup = functools.partial(clean_up_checkpoints, config.checkpoint_dir)
  on_event = cleanup if callback_fn is None else chain_callbacks(cleanup, callback_fn)

  return elastic_manager.elastic_retry(
      max_retries=config.elastic_max_retries,
      timeout=config.elastic_timeout_seconds,
      on_elastic_event_callback=on_event,
  )

src/maxtext/utils/gcs_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Common GCS Utils needed by multiple modules"""
15+
"""Common GCS Utils needed by multiple modules"""
1616
import shutil
1717
import json
1818
import os
1919
import socket
2020
from pathlib import Path
2121
from etils import epath
2222
import uuid
23+
from concurrent.futures import ThreadPoolExecutor
2324

2425
import yaml
2526

@@ -168,6 +169,35 @@ def gcs_list_directories(directory_path):
168169
return directories
169170

170171

172+
def gcs_delete_directory(directory_path: str):
  """Deletes a "directory" (all blobs with the prefix) from GCS.

  Args:
    directory_path: The GCS path (gs://...) representing the "directory" to delete.
  """
  if not _gcs_guard("gcs_delete_directory"):
    return
  bucket_name, prefix = parse_gcs_bucket_and_prefix(directory_path)
  bucket = storage.Client().bucket(bucket_name)

  # A trailing slash keeps the prefix from matching sibling paths
  # (e.g. "runs/1" must not also match "runs/10").
  if not prefix.endswith("/"):
    prefix += "/"

  matching_blobs = list(bucket.list_blobs(prefix=prefix))
  if not matching_blobs:
    return

  def _delete_blob(blob):
    # Best-effort delete: log per-blob failures and keep going.
    try:
      blob.delete()
    except Exception as e:  # pylint: disable=broad-except
      max_logging.log(f"Error deleting blob {blob.name}: {e}")

  # Parallel deletion approximates `gsutil -m` throughput.
  with ThreadPoolExecutor(max_workers=32) as executor:
    executor.map(_delete_blob, matching_blobs)
199+
200+
171201
def gcs_glob_pattern(pattern):
172202
"""
173203
Globs GCS files and returns a list of full GCS paths.

0 commit comments

Comments
 (0)