Skip to content

Commit 8606de1

Browse files
committed
Add NaN and Inf checks for training loss
Squashed history: add config tests; abort only when the training loss is Inf or NaN; make the flags booleans; update configs and tests; run the checks after write_metrics for all platforms; fix line lengths, descriptions, types, checks, and whitespace.
1 parent 093ab89 commit 8606de1

5 files changed

Lines changed: 127 additions & 0 deletions

File tree

src/maxtext/common/metric_logger.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import json
2020
import os
21+
import sys
2122
import queue
2223
import enum
2324

@@ -123,6 +124,9 @@ def write_metrics(self, metrics, step, is_training=True):
123124
if self.config.managed_mldiagnostics:
124125
self.write_metrics_to_managed_mldiagnostics(metrics, step)
125126

127+
if is_training:
128+
self._maybe_abort_after_write_metrics(metrics)
129+
126130
def log_metrics(self, metrics, step, is_training):
127131
"""Logs metrics via max_logging."""
128132
if is_training:
@@ -214,6 +218,16 @@ def _is_profiler_boundary_step(self, step):
214218
}
215219
return step in boundary_steps
216220

221+
def _maybe_abort_after_write_metrics(self, metrics):
222+
""" This function checks whether we have nan or inf values in training"""
223+
loss = metrics["scalar"].get("learning/loss")
224+
if self.config.abort_on_nan_loss and np.isnan(loss):
225+
max_logging.log("Aborting training due to NaN loss.")
226+
sys.exit(1)
227+
if self.config.abort_on_inf_loss and np.isinf(loss):
228+
max_logging.log("Aborting training due to Inf loss.")
229+
sys.exit(1)
230+
217231
def write_metrics_locally(self, metrics, step):
218232
"""Writes metrics locally for testing."""
219233
with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file:

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,8 @@ decode_sampling_temperature: 1.
852852
eval_interval: -1 # the specific number of train step between eval_step
853853
eval_steps: -1 # run this number of steps for eval; recommend setting this to prevent errors due to running out of eval data
854854
target_eval_loss: 0. # early stop once reaching target eval_loss
855+
abort_on_nan_loss: True # Check for NaN and abort if found in training loss
856+
abort_on_inf_loss: True # Check for Inf and abort if found in training loss
855857

856858
# Goodput parameters
857859
enable_goodput_recording: False

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,8 @@ class TrainingLoop(BaseModel):
11101110
0.0,
11111111
description="If set, training will stop early when this evaluation loss is reached.",
11121112
)
1113+
abort_on_nan_loss: bool = Field(True, description="Check for NaN values and abort training.")
1114+
abort_on_inf_loss: bool = Field(True, description="Check for Inf values and abort training.")
11131115
enable_dropout: bool = Field(True, description="Enables dropout in the model.")
11141116
dropout_rate: float = Field(0.0, ge=0.0, le=1.0, description="The dropout rate.")
11151117
enable_data_shuffling: bool = Field(True, description="Enables shuffling of the training data.")

tests/unit/configs_value_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,26 @@ def test_llama3_tokenizer_correction(self):
126126
config = pyconfig.initialize(argv)
127127
self.assertEqual(config.tokenizer_type, "tiktoken")
128128

129+
def test_abort_on_nan_loss_defaults_and_cli_override(self):
  """Tests abort_on_nan_loss default from base.yml and CLI override parsing."""
  # The default comes from base.yml.
  cfg = pyconfig.initialize(["", _BASE_CONFIG_PATH, "run_name=test_default_abort_nan"])
  self.assertEqual(cfg.abort_on_nan_loss, True)

  # A CLI override flips the flag off.
  overridden = pyconfig.initialize(
      ["", _BASE_CONFIG_PATH, "run_name=test_cli_abort_nan", "abort_on_nan_loss=False"]
  )
  self.assertEqual(overridden.abort_on_nan_loss, False)
138+
139+
def test_abort_on_inf_loss_defaults_and_cli_override(self):
  """Tests abort_on_inf_loss default from base.yml and CLI override parsing."""
  # The default comes from base.yml.
  cfg = pyconfig.initialize(["", _BASE_CONFIG_PATH, "run_name=test_default_abort_inf"])
  self.assertEqual(cfg.abort_on_inf_loss, True)

  # A CLI override flips the flag off.
  overridden = pyconfig.initialize(
      ["", _BASE_CONFIG_PATH, "run_name=test_cli_abort_inf", "abort_on_inf_loss=False"]
  )
  self.assertEqual(overridden.abort_on_inf_loss, False)
148+
129149
def test_initialize_pydantic_bad_keys(self):
130150
"""Test that `pydantic.ValidationError` is raised on keys not in MaxTextConfig"""
131151
with self.assertRaises(ValueError):
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for monitoring metrics"""
16+
import unittest
17+
from types import SimpleNamespace
18+
from unittest import mock
19+
20+
import numpy as np
21+
22+
from maxtext.common.metric_logger import MetricLogger
23+
24+
25+
class MetricLoggerAbortTest(unittest.TestCase):
  """Tests the NaN/Inf loss abort behavior of MetricLogger.write_metrics.

  Loggers are built via __new__ (skipping __init__) with a minimal
  SimpleNamespace config, so no real metric writers are constructed.
  """

  def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss):
    """Builds a MetricLogger carrying only the config attributes write_metrics reads."""
    logger = MetricLogger.__new__(MetricLogger)  # skip __init__
    logger.config = SimpleNamespace(
        abort_on_nan_loss=abort_on_nan_loss,
        abort_on_inf_loss=abort_on_inf_loss,
        enable_tensorboard=True,
        metrics_file="/tmp/fake_metrics.jsonl",
        gcs_metrics=True,
        managed_mldiagnostics=True,
    )
    return logger

  def _metrics(self, loss):
    """Returns a minimal metrics payload with the given training loss."""
    return {"scalar": {"learning/loss": loss}}

  @mock.patch("jax.process_index", return_value=0)
  def test_abort_on_nan_exits_after_writes(self, _):
    logger = self._make_logger(True, False)

    with (
        mock.patch.object(logger, "log_metrics") as mock_log,
        mock.patch.object(logger, "write_metrics_to_tensorboard") as mock_tb,
        mock.patch.object(logger, "write_metrics_locally") as mock_local,
        mock.patch.object(logger, "write_metrics_for_gcs") as mock_gcs,
        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics") as mock_mldiag,
    ):
      with self.assertRaises(SystemExit) as exit_ctx:
        logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)

      # Every sink must have been flushed before the abort fires.
      self.assertEqual(exit_ctx.exception.code, 1)
      mock_log.assert_called_once()
      mock_tb.assert_called_once()
      mock_local.assert_called_once()
      mock_gcs.assert_called_once()
      mock_mldiag.assert_called_once()

  @mock.patch("jax.process_index", return_value=0)
  def test_abort_on_inf_exits_after_writes(self, _):
    logger = self._make_logger(False, True)
    with (
        mock.patch.object(logger, "log_metrics"),
        mock.patch.object(logger, "write_metrics_to_tensorboard"),
        mock.patch.object(logger, "write_metrics_locally"),
        mock.patch.object(logger, "write_metrics_for_gcs"),
        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
    ):
      with self.assertRaises(SystemExit):
        logger.write_metrics(self._metrics(np.inf), step=1, is_training=True)

  def test_finite_loss_does_not_exit(self):
    logger = self._make_logger(True, True)
    with (
        mock.patch.object(logger, "log_metrics"),
        mock.patch.object(logger, "write_metrics_to_tensorboard"),
        mock.patch.object(logger, "write_metrics_locally"),
        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
        mock.patch("jax.process_index", return_value=1),  # skip gcs branch
    ):
      # A finite loss must not raise SystemExit even with both flags enabled.
      logger.write_metrics(self._metrics(1.23), step=1, is_training=True)

  def test_abort_flags_disabled_does_not_exit(self):
    logger = self._make_logger(False, False)
    with (
        mock.patch.object(logger, "log_metrics"),
        mock.patch.object(logger, "write_metrics_to_tensorboard"),
        mock.patch.object(logger, "write_metrics_locally"),
        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
        mock.patch("jax.process_index", return_value=1),
    ):
      # NaN loss is tolerated when both abort flags are off.
      logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)

0 commit comments

Comments
 (0)