Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ def train_loop(config, recorder, state=None):
_, setup_params, _ = nnx.split(state.model, nnx.Param, ...)
metric_logger_instance.write_setup_info_to_tensorboard(setup_params)

elastic_utils.record_elastic_reinit_end()

_job_completed_gracefully = False
try:
last_step_completion = datetime.datetime.now()
Expand Down Expand Up @@ -830,6 +832,12 @@ def get_train_func(config, recorder, diagnostic_config, argv):
if config.elastic_enabled:
max_logging.log("Elastic utils: Elastic training enabled.")

def on_elastic_event():
elastic_utils.record_elastic_event_start(recorder, config)

def on_slices_ready():
elastic_utils.record_elastic_wait_end_and_reinit_start(recorder)

def elastic_train_wrapper(argv: Sequence[str]) -> None:
"""Wrapper for elastic training initializes variables and runs the train loop."""
elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv)
Expand All @@ -839,7 +847,11 @@ def elastic_train_wrapper(argv: Sequence[str]) -> None:
elastic_diagnostic_config,
)

train_func = elastic_utils.elastic_retry(config)(functools.partial(elastic_train_wrapper, argv=argv))
train_func = elastic_utils.elastic_retry(
config,
callback_fn=on_elastic_event,
pre_callback_fn=on_slices_ready,
)(functools.partial(elastic_train_wrapper, argv=argv))
else:
# Use the already initialized variables
def train_func():
Expand Down
37 changes: 36 additions & 1 deletion src/maxtext/utils/elastic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,40 @@
from pathwaysutils.elastic import manager

elastic_manager: manager.Manager | None = None
pending_reinit_recorder = None
pending_elastic_event_type = None


def record_elastic_event_start(recorder, config) -> None:
  """Marks the start of an elastic event and remembers which kind it was.

  The event is labelled 'elastic_scale_up' when `is_scale_up_event(config)` is
  truthy and 'elastic_slice_down' otherwise. The label is stored in the
  module-level `pending_elastic_event_type` whether or not a recorder is given.

  Args:
    recorder: Object exposing `record_custom_badput_event_start_time`; nothing
      is recorded when it is falsy.
    config: Forwarded unchanged to `is_scale_up_event`.
  """
  global pending_elastic_event_type
  if is_scale_up_event(config):
    label = 'elastic_scale_up'
  else:
    label = 'elastic_slice_down'
  # Remember the label first so the matching end-time call can reuse it later.
  pending_elastic_event_type = label
  if not recorder:
    return
  recorder.record_custom_badput_event_start_time(custom_badput_event_type=label)


def record_elastic_wait_end_and_reinit_start(recorder) -> None:
  """Closes the pending elastic event and opens the reinitialization event.

  Does nothing when `pending_elastic_event_type` is None. Otherwise the pending
  type is cleared and, if a recorder is given, the end of that event and the
  start of 'elastic_reinitialization' are recorded on it; the recorder is then
  kept in `pending_reinit_recorder` for `record_elastic_reinit_end`.

  Args:
    recorder: Object exposing `record_custom_badput_event_end_time` and
      `record_custom_badput_event_start_time`; nothing is recorded when falsy.
  """
  global pending_reinit_recorder, pending_elastic_event_type
  finished_event = pending_elastic_event_type
  if finished_event is None:
    # No elastic event was started, so there is nothing to close.
    return
  pending_elastic_event_type = None
  if not recorder:
    return
  recorder.record_custom_badput_event_end_time(custom_badput_event_type=finished_event)
  recorder.record_custom_badput_event_start_time(custom_badput_event_type='elastic_reinitialization')
  # NOTE(review): assumed to belong under the recorder guard -- confirm against the original indentation.
  pending_reinit_recorder = recorder


def record_elastic_reinit_end() -> None:
  """Records the end of 'elastic_reinitialization' on the pending recorder.

  Does nothing when `pending_reinit_recorder` is None. Otherwise the end time
  is recorded on that recorder and the module-level reference is cleared.
  """
  global pending_reinit_recorder
  if pending_reinit_recorder is None:
    return
  reinit_recorder = pending_reinit_recorder
  reinit_recorder.record_custom_badput_event_end_time(custom_badput_event_type='elastic_reinitialization')
  # Cleared only after the call returns, so the end time is recorded once per reinitialization.
  pending_reinit_recorder = None


def elastic_enabled(config) -> bool:
Expand Down Expand Up @@ -121,7 +155,7 @@ def wrapper():
return wrapper


def elastic_retry(config, callback_fn=None):
def elastic_retry(config, callback_fn=None, pre_callback_fn=None):
"""Decorator for elastic retry.

If an elastic event occurs, the decorator will retry the decorated function
Expand Down Expand Up @@ -163,6 +197,7 @@ def elastic_retry(config, callback_fn=None):
max_retries=config.elastic_max_retries,
timeout=config.elastic_timeout_seconds,
minimum_slice_count=None if config.elastic_min_slice_count == -1 else config.elastic_min_slice_count,
pre_callback=pre_callback_fn,
on_elastic_event_callback=effective_callback,
)

Expand Down
95 changes: 95 additions & 0 deletions tests/unit/elastic_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def tearDown(self):
pathwaysutils.elastic.manager.Manager = self.original_manager_class
pathwaysutils.elastic.manager.ScaleUpSignalError = self.original_scale_up_signal_error
elastic_utils.elastic_manager = None
elastic_utils.pending_reinit_recorder = None
elastic_utils.pending_elastic_event_type = None
super().tearDown()

def test_elastic_enabled(self):
Expand Down Expand Up @@ -280,6 +282,99 @@ def test_elastic_retry_default_min_slices(self):
kwargs = self.fake_manager.elastic_retry.call_args.kwargs
self.assertIsNone(kwargs["minimum_slice_count"])

def test_elastic_retry_pre_callback_none_by_default(self):
  """Without pre_callback_fn, the manager is called with pre_callback=None."""
  elastic_utils.elastic_manager = self.fake_manager

  elastic_utils.elastic_retry(FakeConfig())

  forwarded_kwargs = self.fake_manager.elastic_retry.call_args.kwargs
  self.assertIsNone(forwarded_kwargs["pre_callback"])

def test_elastic_retry_pre_callback_forwarded(self):
  """A supplied pre_callback_fn reaches the manager as the pre_callback kwarg."""
  elastic_utils.elastic_manager = self.fake_manager
  slices_ready_hook = Mock()

  elastic_utils.elastic_retry(FakeConfig(), pre_callback_fn=slices_ready_hook)

  forwarded_kwargs = self.fake_manager.elastic_retry.call_args.kwargs
  self.assertIs(forwarded_kwargs["pre_callback"], slices_ready_hook)

def test_record_elastic_event_start(self):
  """With new_slice_event.is_set() False, an 'elastic_slice_down' start is recorded and stored."""
  self.fake_manager.new_slice_event.is_set.return_value = False
  elastic_utils.elastic_manager = self.fake_manager
  recorder = Mock()

  elastic_utils.record_elastic_event_start(recorder, FakeConfig())

  expected_type = 'elastic_slice_down'
  recorder.record_custom_badput_event_start_time.assert_called_once_with(custom_badput_event_type=expected_type)
  self.assertEqual(elastic_utils.pending_elastic_event_type, expected_type)

def test_record_elastic_event_start_scale_up(self):
  """With new_slice_event.is_set() True, an 'elastic_scale_up' start is recorded and stored."""
  elastic_utils.elastic_manager = self.fake_manager
  self.fake_manager.new_slice_event.is_set.return_value = True
  fake_recorder = Mock()
  config = FakeConfig()

  elastic_utils.record_elastic_event_start(fake_recorder, config)

  fake_recorder.record_custom_badput_event_start_time.assert_called_once_with(
      custom_badput_event_type='elastic_scale_up'
  )
  # Mirror the slice-down test: the label must also be kept for the later end-time call.
  self.assertEqual(elastic_utils.pending_elastic_event_type, 'elastic_scale_up')

def test_record_elastic_wait_end_and_reinit_start_noop_on_first_attempt(self):
  """With no pending elastic event, nothing is recorded and no recorder is stored."""
  elastic_utils.pending_elastic_event_type = None
  recorder = Mock()

  elastic_utils.record_elastic_wait_end_and_reinit_start(recorder)

  recorder.record_custom_badput_event_end_time.assert_not_called()
  recorder.record_custom_badput_event_start_time.assert_not_called()
  self.assertIsNone(elastic_utils.pending_reinit_recorder)

def test_record_elastic_wait_end_and_reinit_start(self):
  """A pending slice-down event is ended, reinitialization is started, and state is handed over."""
  pending_type = 'elastic_slice_down'
  elastic_utils.pending_elastic_event_type = pending_type
  recorder = Mock()

  elastic_utils.record_elastic_wait_end_and_reinit_start(recorder)

  recorder.record_custom_badput_event_end_time.assert_called_once_with(custom_badput_event_type=pending_type)
  recorder.record_custom_badput_event_start_time.assert_called_once_with(
      custom_badput_event_type='elastic_reinitialization'
  )
  # The recorder is kept for the reinit end call and the pending event type is consumed.
  self.assertIs(elastic_utils.pending_reinit_recorder, recorder)
  self.assertIsNone(elastic_utils.pending_elastic_event_type)

def test_record_elastic_reinit_end(self):
  """A pending recorder gets the reinitialization end time and is then cleared."""
  recorder = Mock()
  elastic_utils.pending_reinit_recorder = recorder

  elastic_utils.record_elastic_reinit_end()

  recorder.record_custom_badput_event_end_time.assert_called_once_with(
      custom_badput_event_type='elastic_reinitialization'
  )
  self.assertIsNone(elastic_utils.pending_reinit_recorder)

def test_record_elastic_reinit_end_on_cold_start(self):
  """With no pending recorder, the call completes without error and leaves the state as None."""
  elastic_utils.pending_reinit_recorder = None

  elastic_utils.record_elastic_reinit_end()

  # Previously the test only checked that no exception was raised.
  self.assertIsNone(elastic_utils.pending_reinit_recorder)

def test_ensure_elastic_manager_initialized_readonly_config(self):
"""Tests that ensure_elastic_manager_initialized works with read-only config."""

Expand Down
Loading