From 23a73036ea75d4d5e1d4aa49fde8a0f6b577e957 Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Wed, 13 May 2026 16:54:43 -0700 Subject: [PATCH] Add goodput elastic events --- src/maxtext/trainers/pre_train/train.py | 14 +++- src/maxtext/utils/elastic_utils.py | 37 +++++++++- tests/unit/elastic_utils_test.py | 95 +++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 2 deletions(-) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 951d10585d..69d27487d2 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -682,6 +682,8 @@ def train_loop(config, recorder, state=None): _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) metric_logger_instance.write_setup_info_to_tensorboard(setup_params) + elastic_utils.record_elastic_reinit_end() + _job_completed_gracefully = False try: last_step_completion = datetime.datetime.now() @@ -830,6 +832,12 @@ def get_train_func(config, recorder, diagnostic_config, argv): if config.elastic_enabled: max_logging.log("Elastic utils: Elastic training enabled.") + def on_elastic_event(): + elastic_utils.record_elastic_event_start(recorder, config) + + def on_slices_ready(): + elastic_utils.record_elastic_wait_end_and_reinit_start(recorder) + def elastic_train_wrapper(argv: Sequence[str]) -> None: """Wrapper for elastic training initializes variables and runs the train loop.""" elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv) @@ -839,7 +847,11 @@ def elastic_train_wrapper(argv: Sequence[str]) -> None: elastic_diagnostic_config, ) - train_func = elastic_utils.elastic_retry(config)(functools.partial(elastic_train_wrapper, argv=argv)) + train_func = elastic_utils.elastic_retry( + config, + callback_fn=on_elastic_event, + pre_callback_fn=on_slices_ready, + )(functools.partial(elastic_train_wrapper, argv=argv)) else: # Use the already initialized variables def train_func(): diff --git 
a/src/maxtext/utils/elastic_utils.py b/src/maxtext/utils/elastic_utils.py index 3028dfd9fc..0b09926a36 100644 --- a/src/maxtext/utils/elastic_utils.py +++ b/src/maxtext/utils/elastic_utils.py @@ -24,6 +24,40 @@ from pathwaysutils.elastic import manager elastic_manager: manager.Manager | None = None +pending_reinit_recorder = None +pending_elastic_event_type = None + + +def record_elastic_event_start(recorder, config) -> None: + """Records start of an elastic scale up or slice down event and remembers its type.""" + global pending_elastic_event_type + event_type = 'elastic_scale_up' if is_scale_up_event(config) else 'elastic_slice_down' + pending_elastic_event_type = event_type + if recorder: + recorder.record_custom_badput_event_start_time(custom_badput_event_type=event_type) + + +def record_elastic_wait_end_and_reinit_start(recorder) -> None: + """Records end of the pending elastic event and start of reinitialization; no-op if none is pending.""" + global pending_reinit_recorder, pending_elastic_event_type + if pending_elastic_event_type is None: + return + event_type = pending_elastic_event_type + pending_elastic_event_type = None + if recorder: + recorder.record_custom_badput_event_end_time(custom_badput_event_type=event_type) + recorder.record_custom_badput_event_start_time(custom_badput_event_type='elastic_reinitialization') + pending_reinit_recorder = recorder + + +def record_elastic_reinit_end() -> None: + """Records end of elastic reinitialization event; no-op if no reinitialization is pending.""" + global pending_reinit_recorder + if pending_reinit_recorder is not None: + pending_reinit_recorder.record_custom_badput_event_end_time( + custom_badput_event_type='elastic_reinitialization' + ) + pending_reinit_recorder = None def elastic_enabled(config) -> bool: @@ -121,7 +155,7 @@ def wrapper(): return wrapper -def elastic_retry(config, callback_fn=None): +def elastic_retry(config, callback_fn=None, pre_callback_fn=None): """Decorator for elastic retry. 
If an elastic event occurs, the decorator will retry the decorated function @@ -163,6 +197,7 @@ def elastic_retry(config, callback_fn=None): max_retries=config.elastic_max_retries, timeout=config.elastic_timeout_seconds, minimum_slice_count=None if config.elastic_min_slice_count == -1 else config.elastic_min_slice_count, + pre_callback=pre_callback_fn, on_elastic_event_callback=effective_callback, ) diff --git a/tests/unit/elastic_utils_test.py b/tests/unit/elastic_utils_test.py index 60b9bcd900..c2c4673575 100644 --- a/tests/unit/elastic_utils_test.py +++ b/tests/unit/elastic_utils_test.py @@ -93,6 +93,8 @@ def tearDown(self): pathwaysutils.elastic.manager.Manager = self.original_manager_class pathwaysutils.elastic.manager.ScaleUpSignalError = self.original_scale_up_signal_error elastic_utils.elastic_manager = None + elastic_utils.pending_reinit_recorder = None + elastic_utils.pending_elastic_event_type = None super().tearDown() def test_elastic_enabled(self): @@ -280,6 +282,99 @@ def test_elastic_retry_default_min_slices(self): kwargs = self.fake_manager.elastic_retry.call_args.kwargs self.assertIsNone(kwargs["minimum_slice_count"]) + def test_elastic_retry_pre_callback_none_by_default(self): + """pre_callback must be None when pre_callback_fn is not supplied.""" + config = FakeConfig() + elastic_utils.elastic_manager = self.fake_manager + + elastic_utils.elastic_retry(config) + + kwargs = self.fake_manager.elastic_retry.call_args.kwargs + self.assertIsNone(kwargs["pre_callback"]) + + def test_elastic_retry_pre_callback_forwarded(self): + """pre_callback_fn must be forwarded as pre_callback to the manager.""" + config = FakeConfig() + elastic_utils.elastic_manager = self.fake_manager + + fake_pre_callback = Mock() + elastic_utils.elastic_retry(config, pre_callback_fn=fake_pre_callback) + + kwargs = self.fake_manager.elastic_retry.call_args.kwargs + self.assertIs(kwargs["pre_callback"], fake_pre_callback) + + def test_record_elastic_event_start(self): + """Tests 
recording an elastic slice down start.""" + elastic_utils.elastic_manager = self.fake_manager + self.fake_manager.new_slice_event.is_set.return_value = False + fake_recorder = Mock() + config = FakeConfig() + + elastic_utils.record_elastic_event_start(fake_recorder, config) + + fake_recorder.record_custom_badput_event_start_time.assert_called_once_with( + custom_badput_event_type='elastic_slice_down' + ) + self.assertEqual(elastic_utils.pending_elastic_event_type, 'elastic_slice_down') + + def test_record_elastic_event_start_scale_up(self): + """Tests recording an elastic scale up start.""" + elastic_utils.elastic_manager = self.fake_manager + self.fake_manager.new_slice_event.is_set.return_value = True + fake_recorder = Mock() + config = FakeConfig() + + elastic_utils.record_elastic_event_start(fake_recorder, config) + + fake_recorder.record_custom_badput_event_start_time.assert_called_once_with( + custom_badput_event_type='elastic_scale_up' + ) + + def test_record_elastic_wait_end_and_reinit_start_noop_on_first_attempt(self): + """Tests that nothing is recorded when no elastic event is pending.""" + elastic_utils.pending_elastic_event_type = None + fake_recorder = Mock() + + elastic_utils.record_elastic_wait_end_and_reinit_start(fake_recorder) + + fake_recorder.record_custom_badput_event_end_time.assert_not_called() + fake_recorder.record_custom_badput_event_start_time.assert_not_called() + self.assertIsNone(elastic_utils.pending_reinit_recorder) + + def test_record_elastic_wait_end_and_reinit_start(self): + """Tests recording end of slice down and start of reinit.""" + elastic_utils.pending_elastic_event_type = 'elastic_slice_down' + fake_recorder = Mock() + + elastic_utils.record_elastic_wait_end_and_reinit_start(fake_recorder) + + fake_recorder.record_custom_badput_event_end_time.assert_called_once_with( + custom_badput_event_type='elastic_slice_down' + ) + fake_recorder.record_custom_badput_event_start_time.assert_called_once_with( + 
custom_badput_event_type='elastic_reinitialization' + ) + self.assertIs(elastic_utils.pending_reinit_recorder, fake_recorder) + self.assertIsNone(elastic_utils.pending_elastic_event_type) + + def test_record_elastic_reinit_end(self): + """Tests recording end of elastic reinit.""" + fake_recorder = Mock() + elastic_utils.pending_reinit_recorder = fake_recorder + + elastic_utils.record_elastic_reinit_end() + + fake_recorder.record_custom_badput_event_end_time.assert_called_once_with( + custom_badput_event_type='elastic_reinitialization' + ) + self.assertIsNone(elastic_utils.pending_reinit_recorder) + + def test_record_elastic_reinit_end_on_cold_start(self): + """Tests recording end of elastic reinit on cold start.""" + elastic_utils.pending_reinit_recorder = None + + elastic_utils.record_elastic_reinit_end() + def test_ensure_elastic_manager_initialized_readonly_config(self): """Tests that ensure_elastic_manager_initialized works with read-only config."""