diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py
index 5ce7d5e936..56c8ba4462 100644
--- a/src/maxtext/common/metric_logger.py
+++ b/src/maxtext/common/metric_logger.py
@@ -18,6 +18,7 @@
 import json
 import os
+import sys
 import queue
 import enum
@@ -123,6 +124,9 @@ def write_metrics(self, metrics, step, is_training=True):
     if self.config.managed_mldiagnostics:
       self.write_metrics_to_managed_mldiagnostics(metrics, step)
 
+    if is_training:
+      self._maybe_abort_after_write_metrics(metrics)
+
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
@@ -214,6 +218,18 @@ def _is_profiler_boundary_step(self, step):
     }
     return step in boundary_steps
 
+  def _maybe_abort_after_write_metrics(self, metrics):
+    """Exits the process if the training loss is NaN/Inf and the matching abort flag is set.
+
+    Runs only after all metric writers have fired, so the offending step is still recorded.
+    """
+    loss = metrics["scalar"].get("learning/loss")
+    if loss is None:
+      # Metric absent for this step: np.isnan(None)/np.isinf(None) would raise TypeError,
+      # so skip the check rather than crash the abort path itself.
+      return
+    if self.config.abort_on_nan_loss and np.isnan(loss):
+      max_logging.log("Aborting training due to NaN loss.")
+      sys.exit(1)
+    if self.config.abort_on_inf_loss and np.isinf(loss):
+      max_logging.log("Aborting training due to Inf loss.")
+      sys.exit(1)
+
   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""
     with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file:
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 398df849fe..dccf72e7a0 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -852,6 +852,8 @@ decode_sampling_temperature: 1.
 eval_interval: -1 # the specific number of train step between eval_step
 eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data
 target_eval_loss: 0. 
# early stop once reaching target eval_loss +abort_on_nan_loss: True # Check for NaN and abort if found in training loss +abort_on_inf_loss: True # Check for Inf and abort if found in training loss # Goodput parameters enable_goodput_recording: False diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3df51ac106..4987033f07 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1110,6 +1110,8 @@ class TrainingLoop(BaseModel): 0.0, description="If set, training will stop early when this evaluation loss is reached.", ) + abort_on_nan_loss: bool = Field(True, description="Check for NaN values and abort training.") + abort_on_inf_loss: bool = Field(True, description="Check for Inf values and abort training.") enable_dropout: bool = Field(True, description="Enables dropout in the model.") dropout_rate: float = Field(0.0, ge=0.0, le=1.0, description="The dropout rate.") enable_data_shuffling: bool = Field(True, description="Enables shuffling of the training data.") diff --git a/tests/unit/metric_logger_abort_test.py b/tests/unit/metric_logger_abort_test.py new file mode 100644 index 0000000000..95eefca88b --- /dev/null +++ b/tests/unit/metric_logger_abort_test.py @@ -0,0 +1,89 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for monitoring metrics""" +import unittest +from types import SimpleNamespace +from unittest import mock + +import numpy as np + +from maxtext.common.metric_logger import MetricLogger + + +class MetricLoggerAbortTest(unittest.TestCase): + def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss): + logger = MetricLogger.__new__(MetricLogger) # skip __init__ + logger.config = SimpleNamespace( + abort_on_nan_loss=abort_on_nan_loss, + abort_on_inf_loss=abort_on_inf_loss, + enable_tensorboard=True, + metrics_file="/tmp/fake_metrics.jsonl", + gcs_metrics=True, + managed_mldiagnostics=True, + ) + return logger + + def _metrics(self, loss): + return {"scalar": {"learning/loss": loss}} + + @mock.patch("jax.process_index", return_value=0) + def test_abort_on_nan_exits_after_writes(self, _): + logger = self._make_logger(True, False) + + with ( + mock.patch.object(logger, "log_metrics") as log_metrics, + mock.patch.object(logger, "write_metrics_to_tensorboard") as tb, + mock.patch.object(logger, "write_metrics_locally") as local, + mock.patch.object(logger, "write_metrics_for_gcs") as gcs, + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics") as mldiag, + ): + with self.assertRaises(SystemExit) as cm: + logger.write_metrics(self._metrics(np.nan), step=1, is_training=True) + + self.assertEqual(cm.exception.code, 1) + log_metrics.assert_called_once() + tb.assert_called_once() + local.assert_called_once() + gcs.assert_called_once() + mldiag.assert_called_once() + + @mock.patch("jax.process_index", return_value=0) + def test_abort_on_inf_exits_after_writes(self, _): + logger = self._make_logger(False, True) + with mock.patch.object(logger, "log_metrics"), \ + mock.patch.object(logger, "write_metrics_to_tensorboard"), \ + mock.patch.object(logger, "write_metrics_locally"), \ + mock.patch.object(logger, "write_metrics_for_gcs"), \ + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"): + with self.assertRaises(SystemExit): + 
logger.write_metrics(self._metrics(np.inf), step=1, is_training=True) + + def test_finite_loss_does_not_exit(self): + logger = self._make_logger(True, True) + with mock.patch.object(logger, "log_metrics"), \ + mock.patch.object(logger, "write_metrics_to_tensorboard"), \ + mock.patch.object(logger, "write_metrics_locally"), \ + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \ + mock.patch("jax.process_index", return_value=1): # skip gcs branch + logger.write_metrics(self._metrics(1.23), step=1, is_training=True) + + def test_abort_flags_disabled_does_not_exit(self): + logger = self._make_logger(False, False) + with mock.patch.object(logger, "log_metrics"), \ + mock.patch.object(logger, "write_metrics_to_tensorboard"), \ + mock.patch.object(logger, "write_metrics_locally"), \ + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \ + mock.patch("jax.process_index", return_value=1): + logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)