From 77f05922f99838f3dab3a5a3d5673fe9128e8813 Mon Sep 17 00:00:00 2001 From: Charles Li Date: Tue, 10 Mar 2026 00:46:05 +0000 Subject: [PATCH 1/5] Add nnx_train --- src/maxtext/trainers/pre_train/nnx_train.py | 690 ++++++++++++++++++++ 1 file changed, 690 insertions(+) create mode 100644 src/maxtext/trainers/pre_train/nnx_train.py diff --git a/src/maxtext/trainers/pre_train/nnx_train.py b/src/maxtext/trainers/pre_train/nnx_train.py new file mode 100644 index 0000000000..58355f3b48 --- /dev/null +++ b/src/maxtext/trainers/pre_train/nnx_train.py @@ -0,0 +1,690 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX-native pre-training loop for MaxText. + +This module implements a pre-training loop that uses the Flax NNX API throughout, +in contrast to train.py which wraps NNX models inside Linen's TrainState. 
+ + + Architecture + + ┌─────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ + │ Layer │ What it does │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ loss_fn / eval_loss_fn │ Forward-pass + cross-entropy; called directly on an nnx.Module │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ train_step │ Functional step — merges (graphdef, opt_state) → runs nnx.value_and_grad │ + │ │ → updates optimizer → returns new nnx.State │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ eval_step │ Same merge pattern, forward-only, no grads │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _create_and_shard_optimizer │ Wraps model + optax tx in nnx.Optimizer, derives partition specs via │ + │ │ nnx.get_partition_spec, shards state with jax.jit(out_shardings=…) │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _build_jit_steps │ Partially applies static (graphdef, config) then wraps with │ + │ │ jax.jit(in_shardings, out_shardings, donate_argnums=(0,1)) │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ _maybe_restore_checkpoint / │ Orbax round-trip using the NNX {"value": array} wire format │ + │ _maybe_save_checkpoint │ │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ train_loop │ Full loop: model → optimizer → data → checkpoint → JIT compile → step → │ + │ │ eval → log │ + ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ + │ main / initialize / run │ Entry-point boilerplate matching 
train.py conventions │ + └─────────────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + + Key differences from train.py + + - No Linen TrainState — state lives in nnx.Optimizer (model params + optax state + step counter). + - Gradient computation uses nnx.value_and_grad, which is NNX-graph-aware. It differentiates only through + nnx.Param variables; non-differentiable NNX variables (RNGs, cache, …) are untouched. + - Gradient clipping uses optax.clip_by_global_norm directly, avoiding the Linen-TrainState coupling in + apply_gradient_clipping. + - JIT boundary: graphdef is a Python-static closure; only opt_state (a plain pytree of arrays) crosses the JIT + boundary with donate_argnums=(0,1) + - The JIT boundary uses split/merge so that graphdef is static and state is + donated as a pytree, preserving full sharding control via jax.jit shardings. + - Checkpointing saves/restores the raw nnx.State pytree via Orbax. + +Entry point: + python -m maxtext.trainers.pre_train.nnx_train [overrides…] +""" + +import contextlib +import datetime +import functools +import os +from typing import Any, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from absl import app +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from jax.sharding import Mesh + +from maxtext.common import checkpointing, profiler +from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.data_loader import create_dataloader +from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag +from maxtext.common.gcloud_stub import is_decoupled, vertex_tensorboard_modules +from maxtext.common.goodput import ( + RECORD_JOB_END_TIME, + RECORD_JOB_START_TIME, + GoodputEvent, + create_goodput_recorder, + maybe_monitor_goodput, + maybe_record_goodput, + record_goodput, +) +from maxtext.common.metric_logger import MetricLogger +from 
maxtext.configs import pyconfig +from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator +from maxtext.optimizers import optimizers +from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding +from maxtext.utils.globals import EPS +from maxtext.utils.rampup_batch import create_rampup_manager + +_diag_modules = _cloud_diag() +diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules +VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() + + +# --------------------------------------------------------------------------- +# Loss computation +# --------------------------------------------------------------------------- + + +def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array): + """Compute cross-entropy loss for one batch using an NNX model. + + Args: + model: The NNX Transformer (or compatible) model. Called in-place; no + explicit params argument is needed because the NNX module carries state. + config: MaxText Config object. + data: Batch dict with keys "inputs", "inputs_position", "inputs_segmentation", + "targets", "targets_segmentation". + dropout_rng: PRNG key used to seed dropout layers. + + Returns: + (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics. 
+ """ + rng1, aqt_rng = jax.random.split(dropout_rng) + + # Trim to micro-batch size (handles per_device_batch_size < 1 cases) + batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()} + + logits = model( + decoder_input_tokens=batch["inputs"], + decoder_positions=batch["inputs_position"], + decoder_segment_ids=batch["inputs_segmentation"], + enable_dropout=config.enable_dropout, + ) + + one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size) + xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) + + # Zero out padding positions + target_mask = batch["targets_segmentation"] != 0 + xent = xent * target_mask + z_loss = z_loss * target_mask + + total_loss = jnp.sum(xent) + total_weights = jnp.sum(target_mask) + total_z_loss = jnp.sum(z_loss) / (total_weights + EPS) + + loss = total_loss / (total_weights + EPS) + + aux = { + "total_loss": total_loss, + "z_loss": total_z_loss, + "total_weights": total_weights, + } + return loss, aux + + +def eval_loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array): + """Evaluation variant of loss_fn (no dropout, full batch size).""" + batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()} + + logits = model( + decoder_input_tokens=batch["inputs"], + decoder_positions=batch["inputs_position"], + decoder_segment_ids=batch["inputs_segmentation"], + enable_dropout=False, + ) + + one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size) + xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) + + target_mask = batch["targets_segmentation"] != 0 + xent = xent * target_mask + z_loss = z_loss * target_mask + + total_loss = jnp.sum(xent) + total_weights = jnp.sum(target_mask) + total_z_loss = jnp.sum(z_loss) / (total_weights + EPS) + + loss = total_loss / (total_weights + EPS) + + aux = { + "total_loss": total_loss, + 
"z_loss": total_z_loss, + "total_weights": total_weights, + } + return loss, aux + + +# --------------------------------------------------------------------------- +# Train / eval steps (purely functional, JIT-able) +# --------------------------------------------------------------------------- + + +def train_step( + model_graphdef: nnx.graph.NodeDef, + opt_graphdef: nnx.graph.NodeDef, + model_state: nnx.State, + opt_state: nnx.State, + data: dict[str, jax.Array], + dropout_rng: jax.Array, + config, +): + """One training step: forward + backward + optimizer update. + + Args: + model_graphdef: Static NNX graph definition for the model (JIT closure). + opt_graphdef: Static NNX graph definition for the optimizer (JIT closure). + model_state: Mutable model parameter pytree (donated). + opt_state: Mutable optimizer state pytree (donated). + data: Batch of token IDs and metadata. + dropout_rng: PRNG key for dropout. + config: MaxText Config. + + Returns: + (new_model_state, new_opt_state): Updated pytrees. + metrics: Dict of scalar training metrics. + """ + model: nnx.Module = nnx.merge(model_graphdef, model_state) + optimizer: nnx.Optimizer = nnx.merge(opt_graphdef, opt_state) + + # Compute loss and gradients w.r.t. model parameters. + # nnx.value_and_grad differentiates only through nnx.Param variables, + # keeping non-differentiable state (RNGs, cache, etc.) frozen. 
+ grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True) + (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng) + + # Cast gradients to configured dtype before clipping / accumulation + raw_grads = jax.tree.map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) + + # Gradient clipping (implemented directly to avoid Linen TrainState dependency) + if config.gradient_clipping_threshold > 0: + clip_tx = optax.clip_by_global_norm(config.gradient_clipping_threshold) + grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None) + else: + grads = raw_grads + + # NNX 0.11+: update takes (model, grads) explicitly. + optimizer.update(model, grads) + + new_model_state = nnx.state(model) + new_opt_state = nnx.state(optimizer) + + scalar_metrics = { + "learning/loss": loss, + "learning/z_loss": aux["z_loss"], + "learning/total_weights": aux["total_weights"], + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(nnx.state(model, nnx.Param)), + } + metrics = {"scalar": scalar_metrics, "scalars": {}} + return (new_model_state, new_opt_state), metrics + + +def eval_step( + model_graphdef: nnx.graph.NodeDef, + model_state: nnx.State, + data: dict[str, jax.Array], + dropout_rng: jax.Array, + config, +): + """One evaluation step: forward only, no gradient computation. + + Args: + model_graphdef: Static NNX graph definition for the model. + model_state: Current model parameter pytree (read-only). + data: Batch of token IDs and metadata. + dropout_rng: PRNG key (dropout disabled for eval, but kept for API symmetry). + config: MaxText Config. + + Returns: + metrics: Dict of scalar evaluation metrics. 
+ """ + model: nnx.Module = nnx.merge(model_graphdef, model_state) + loss, aux = eval_loss_fn(model, config, data, dropout_rng) + + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/z_loss": aux["z_loss"], + "evaluation/total_loss": aux["total_loss"], + "evaluation/total_weights": aux["total_weights"], + } + } + return metrics + + +# --------------------------------------------------------------------------- +# Training-loop setup +# --------------------------------------------------------------------------- + + +def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh): + """Creates an nnx.Optimizer and returns sharded model + optimizer states. + + In NNX 0.11+, the optimizer does not hold a model reference, so model and + optimizer are kept as independent objects with separate graphdefs, state + pytrees, and sharding specs throughout the training loop. + + Args: + model: Sharded NNX model (already placed on devices). + config: MaxText Config. + mesh: JAX device mesh. + + Returns: + model_graphdef: Static NNX graph definition for the model. + opt_graphdef: Static NNX graph definition for the optimizer. + model_state: Sharded model parameter pytree (donated to JIT steps). + opt_state: Sharded optimizer state pytree (donated to JIT steps). + model_shardings: Partition specs for model_state. + opt_shardings: Partition specs for opt_state. + learning_rate_schedule: Learning-rate schedule function. + """ + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + # NNX 0.11+: wrt is mandatory; optimizer does not store a model reference. + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + + # Derive separate partition specs for model and optimizer. 
+ model_graphdef, abstract_model_state = nnx.split(nnx.eval_shape(lambda: model)) + opt_graphdef, abstract_opt_state = nnx.split(nnx.eval_shape(lambda: optimizer)) + + with nn.logical_axis_rules(config.logical_axis_rules): + model_shardings = nn.logical_to_mesh_sharding( + nnx.get_partition_spec(abstract_model_state), mesh, config.logical_axis_rules + ) + opt_shardings = nn.logical_to_mesh_sharding( + nnx.get_partition_spec(abstract_opt_state), mesh, config.logical_axis_rules + ) + + _, model_state = nnx.split(model) + _, opt_state = nnx.split(optimizer) + + @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings)) + def shard_states(ms, os): + return ms, os + + with mesh: + model_state, opt_state = shard_states(model_state, opt_state) + + return model_graphdef, opt_graphdef, model_state, opt_state, model_shardings, opt_shardings, learning_rate_schedule + + +def _get_first_step(opt_state: nnx.State) -> int: + """Extracts the current step counter from the optimizer state.""" + # nnx.Optimizer stores step as an nnx.Variable; its value is a scalar. + step_leaves = [v for k, v in opt_state.flat_state().items() if "step" in str(k)] + if step_leaves: + return int(step_leaves[0]) + return 0 + + +def _build_jit_steps( + config, + model_graphdef: nnx.graph.NodeDef, + opt_graphdef: nnx.graph.NodeDef, + mesh: Mesh, + model_shardings: Any, + opt_shardings: Any, + eval_data_iterator, +): + """JIT-compiles the train and eval step functions with sharding annotations. + + Returns: + p_train_step: JIT-compiled train step. + p_eval_step: JIT-compiled eval step (None if no eval data). + """ + data_sharding = sharding.get_input_data_sharding(config, mesh) + + # Partial application captures static graphdefs and config outside JIT. 
+ _train_fn = functools.partial(train_step, model_graphdef, opt_graphdef, config=config) + _train_fn.__name__ = "nnx_train_step" + + p_train_step = jax.jit( + _train_fn, + in_shardings=(model_shardings, opt_shardings, data_sharding, None), + out_shardings=((model_shardings, opt_shardings), None), + donate_argnums=(0, 1), # donate both model_state and opt_state buffers + ) + + p_eval_step = None + if eval_data_iterator is not None: + # Eval only needs the model; optimizer state is not required. + _eval_fn = functools.partial(eval_step, model_graphdef, config=config) + _eval_fn.__name__ = "nnx_eval_step" + p_eval_step = jax.jit( + _eval_fn, + in_shardings=(model_shardings, data_sharding, None), + out_shardings=None, + donate_argnums=(), + ) + + return p_train_step, p_eval_step + + +def _wrap_state(state: nnx.State): + """Wraps each leaf in {"value": ...} to match the NNX checkpoint format.""" + return jax.tree.map(lambda v: {"value": v}, state, is_leaf=lambda n: isinstance(n, nnx.Variable)) + + +def _unwrap_state(raw): + """Unwraps {"value": ...} leaves back to plain arrays.""" + return jax.tree.map(lambda v: v["value"], raw, is_leaf=lambda x: isinstance(x, dict) and "value" in x) + + +def _maybe_restore_checkpoint(checkpoint_manager, model_state: nnx.State, opt_state: nnx.State, config, data_iterator): + """Restores model and optimizer states from an Orbax checkpoint if one exists. + + Checkpoint layout: {"model": , "optimizer": }, + with every leaf wrapped as {"value": }. 
+ + Returns: + (model_state, opt_state, data_iterator, start_step) + """ + if checkpoint_manager is None: + return model_state, opt_state, data_iterator, 0 + + try: + import orbax.checkpoint as ocp # pylint: disable=import-outside-toplevel + + latest = checkpoint_manager.latest_step() + if latest is None: + max_logging.log("No existing checkpoint found; starting from scratch.") + return model_state, opt_state, data_iterator, 0 + + max_logging.log(f"Restoring NNX checkpoint from step {latest}.") + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=config.checkpoint_storage_concurrent_gb, + save_concurrent_gb=config.checkpoint_storage_concurrent_gb, + use_ocdbt=config.checkpoint_storage_use_ocdbt, + use_zarr3=config.checkpoint_storage_use_zarr3, + ) + ) + + target = {"model": _wrap_state(model_state), "optimizer": _wrap_state(opt_state)} + restore_args = ocp.checkpoint_utils.construct_restore_args(target) + checkpoint_dir = checkpoint_manager.directory / str(latest) + restored_raw = ckptr.restore(checkpoint_dir, item=target, restore_args=restore_args) + + restored_model_state = _unwrap_state(restored_raw["model"]) + restored_opt_state = _unwrap_state(restored_raw["optimizer"]) + return restored_model_state, restored_opt_state, data_iterator, int(latest) + + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"Checkpoint restore failed ({e}); starting from scratch.") + return model_state, opt_state, data_iterator, 0 + + +def _maybe_save_checkpoint( + checkpoint_manager, model_state: nnx.State, opt_state: nnx.State, config, data_iterator, step: int +): + """Saves model and optimizer states to an Orbax checkpoint.""" + if checkpoint_manager is None: + return + state_to_save = {"model": _wrap_state(model_state), "optimizer": _wrap_state(opt_state)} + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + +# 
--------------------------------------------------------------------------- +# Main training loop +# --------------------------------------------------------------------------- + + +def train_loop(config, recorder, state=None): + """NNX pre-training loop. + + Args: + config: MaxText Config. + recorder: Goodput recorder (may be None). + state: Unused; present for API symmetry with train.py. + + Returns: + Final optimizer state pytree. + """ + # ---- Model ---------------------------------------------------------------- + with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): + with nn.logical_axis_rules(config.logical_axis_rules): + model, mesh = model_creation_utils.create_nnx_model(config) + + # ---- Optimizer + sharding ------------------------------------------------- + with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): + model_graphdef, opt_graphdef, model_state, opt_state, model_shardings, opt_shardings, learning_rate_schedule = ( + _create_and_shard_optimizer(model, config, mesh) + ) + + # ---- Data --------------------------------------------------------------- + with jax.set_mesh(mesh): + data_iterator, eval_data_iterator = create_data_iterator(config, mesh) + rampup_manager = create_rampup_manager(config, checkpoint_manager=None) + data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager) + + # ---- Checkpointing ------------------------------------------------------- + logger = checkpointing.setup_checkpoint_logger(config) + checkpoint_dir = config.checkpoint_dir if config.enable_checkpointing else "" + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + checkpoint_dir, + config.enable_checkpointing, + config.async_checkpointing, + config.checkpoint_period, + config.dataset_type, + logger, + config.checkpoint_storage_use_ocdbt, + config.checkpoint_storage_use_zarr3, + config.enable_continuous_checkpointing, + config.max_num_checkpoints_to_keep, + 
config.checkpoint_storage_concurrent_gb, + config.enable_single_controller, + config.colocated_python_checkpointing, + config.enable_single_replica_ckpt_restoring, + ) + + model_state, opt_state, data_iterator, start_step = _maybe_restore_checkpoint( + checkpoint_manager, model_state, opt_state, config, data_iterator + ) + + # ---- JIT-compile steps ---------------------------------------------------- + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + p_train_step, p_eval_step = _build_jit_steps( + config, model_graphdef, opt_graphdef, mesh, model_shardings, opt_shardings, eval_data_iterator + ) + + # Trigger AOT compilation and print memory stats + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + shaped_batch = maxtext_utils.get_shaped_batch(config) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + example_rng = jax.jit(jax.random.fold_in)(init_rng, 0) + if config.compiled_trainstep_file == "": + compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile() + compiled_stats = compiled.memory_analysis() + max_utils.print_compiled_memory_stats(compiled_stats) + + # ---- Profiler / logger ---------------------------------------------------- + prof = profiler.Profiler(config, offset_step=start_step) + metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) + + # Write train config params, num model params, and XLA flags to tensorboard + metric_logger.write_setup_info_to_tensorboard(model_state) + + # ---- Main loop ------------------------------------------------------------ + _job_completed_gracefully = False + try: + last_step_completion = datetime.datetime.now() + + for step in np.arange(start_step, config.steps): + prof.maybe_activate_profiler(step, opt_state) + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) + nextrng = 
jax.jit(jax.random.fold_in)(init_rng, step) + + with maybe_record_goodput(recorder, GoodputEvent.STEP, step): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + _maybe_save_checkpoint(checkpoint_manager, model_state, opt_state, config, data_iterator, step) + + # ---- Optional eval ------------------------------------------------------- + if ( + p_eval_step is not None + and config.eval_interval > 0 + and step > start_step + and (step + 1) % config.eval_interval == 0 + ): + assert eval_data_iterator + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + eval_step_count = 0 + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(model_state, eval_batch, nextrng) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + eval_step_count += 1 + + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} achieved.") + + prof.maybe_deactivate_profiler(step, opt_state) + + if step == start_step: + max_utils.print_mem_stats("After first step") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + + # Final checkpoint on loop completion + if config.save_checkpoint_on_completion: + _maybe_save_checkpoint( + checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1) + ) + if checkpoint_manager is not None: + checkpoint_manager.wait_until_finished() + + 
_job_completed_gracefully = True + + except exceptions.StopTraining as e: + max_logging.log(f"Training stopped: {str(e)}") + _job_completed_gracefully = True + + finally: + if _job_completed_gracefully: + record_goodput(recorder, RECORD_JOB_END_TIME) + metric_logger.flush_metrics_and_cleanup() + + return opt_state + + +# --------------------------------------------------------------------------- +# Entry-point helpers +# --------------------------------------------------------------------------- + + +def initialize(argv: Sequence[str]): + """Initialise hyperparameters and utility objects.""" + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + + import tensorflow as tf # pylint: disable=import-outside-toplevel + + tf.config.set_visible_devices([], "GPU") + + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + + config = pyconfig.initialize(argv) + max_utils.print_system_information() + + if not config.enable_nnx: + max_logging.log("WARNING: nnx_train.py requires enable_nnx=True. 
Forcing it on.")
+
+  if config.shard_mode == ShardMode.EXPLICIT:
+    jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
+
+  os.environ["TFDS_DATA_DIR"] = config.dataset_path or ""
+
+  vertex_tensorboard_manager = VertexTensorboardManager()
+  if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
+    vertex_tensorboard_manager.configure_vertex_tensorboard(config)
+
+  recorder = create_goodput_recorder(config)
+
+  debug_config = debug_configuration.DebugConfig(
+      stack_trace_config=stack_trace_configuration.StackTraceConfig(
+          collect_stack_trace=config.collect_stack_trace,
+          stack_trace_to_cloud=config.stack_trace_to_cloud,
+          stack_trace_interval_seconds=config.stack_trace_interval_seconds,
+      )
+  )
+  diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
+  return config, recorder, diagnostic_config
+
+
+def run(config, recorder, diagnostic_config):
+  """Run the NNX training job."""
+  diagnostics_context = (
+      contextlib.nullcontext()
+      if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag"
+      else diagnostic.diagnose(diagnostic_config)
+  )
+
+  with (
+      diagnostics_context,
+      max_utils.maybe_get_transformer_engine_context(config),
+  ):
+    train_loop(config, recorder)
+
+
+def main(argv: Sequence[str]) -> None:
+  config, recorder, diagnostic_config = initialize(argv)
+  record_goodput(recorder, RECORD_JOB_START_TIME)
+  with maybe_monitor_goodput(config):
+    run(config, recorder, diagnostic_config)
+
+
+if __name__ == "__main__":
+  app.run(main)

From c311ead8702bda175f89027b7129a9f2bcb24f0d Mon Sep 17 00:00:00 2001
From: Charles Li
Date: Tue, 10 Mar 2026 17:04:20 +0000
Subject: [PATCH 2/5] Combine loss_fn for both train and eval

---
 src/maxtext/trainers/pre_train/nnx_train.py | 150 +++++++++++++++-----
 1 file changed, 113 insertions(+), 37 deletions(-)

diff --git a/src/maxtext/trainers/pre_train/nnx_train.py b/src/maxtext/trainers/pre_train/nnx_train.py
index 58355f3b48..dc545ce218 
100644 --- a/src/maxtext/trainers/pre_train/nnx_train.py +++ b/src/maxtext/trainers/pre_train/nnx_train.py @@ -21,9 +21,10 @@ Architecture ┌─────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ - │ Layer │ What it does │ + │function │ What it does │ ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ - │ loss_fn / eval_loss_fn │ Forward-pass + cross-entropy; called directly on an nnx.Module │ + │ loss_fn │ Forward-pass + cross-entropy; for both train and eval; │ + │ │ called directly on an nnx.Module │ ├─────────────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ │ train_step │ Functional step — merges (graphdef, opt_state) → runs nnx.value_and_grad │ │ │ → updates optimizer → returns new nnx.State │ @@ -79,7 +80,7 @@ from jax.sharding import Mesh from maxtext.common import checkpointing, profiler -from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import ShardMode from maxtext.common.data_loader import create_dataloader from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag from maxtext.common.gcloud_stub import is_decoupled, vertex_tensorboard_modules @@ -95,6 +96,7 @@ from maxtext.common.metric_logger import MetricLogger from maxtext.configs import pyconfig from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator +from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss from maxtext.optimizers import optimizers from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding from maxtext.utils.globals import EPS @@ -106,11 +108,11 @@ # --------------------------------------------------------------------------- -# Loss computation +# Loss computation for both train and eval # 
--------------------------------------------------------------------------- -def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array): +def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array, is_train=True): """Compute cross-entropy loss for one batch using an NNX model. Args: @@ -120,6 +122,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: data: Batch dict with keys "inputs", "inputs_position", "inputs_segmentation", "targets", "targets_segmentation". dropout_rng: PRNG key used to seed dropout layers. + is_train: True for train_step and False for eval_step. Returns: (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics. @@ -127,51 +130,109 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: rng1, aqt_rng = jax.random.split(dropout_rng) # Trim to micro-batch size (handles per_device_batch_size < 1 cases) - batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()} + # decimate proportion of data when per_device_batch_size<1 + if is_train: + batch = {k: v[: config.micro_batch_size_to_train_on, :] for k, v in data.items()} + else: + batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()} + # Flax NNX model logits = model( decoder_input_tokens=batch["inputs"], decoder_positions=batch["inputs_position"], decoder_segment_ids=batch["inputs_segmentation"], - enable_dropout=config.enable_dropout, + encoder_images=batch["images"] if config.use_multimodal else None, + encoder_image_masks=batch["image_masks"] if config.use_multimodal and "image_masks" in batch else None, + enable_dropout=config.enable_dropout if is_train else False, + decoder_target_tokens=batch["targets"], + decoder_target_mask=batch["targets_segmentation"], ) - + intermediate_outputs = {} one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size) xent, z_loss = 
max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - # Zero out padding positions - target_mask = batch["targets_segmentation"] != 0 - xent = xent * target_mask - z_loss = z_loss * target_mask + xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) + z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length")) - total_loss = jnp.sum(xent) - total_weights = jnp.sum(target_mask) - total_z_loss = jnp.sum(z_loss) / (total_weights + EPS) + # Mask out paddings at the end of each example. + xent = xent * (batch["targets_segmentation"] != 0) + z_loss = z_loss * (batch["targets_segmentation"] != 0) - loss = total_loss / (total_weights + EPS) + total_loss = jnp.sum(xent) + total_z_loss = jnp.sum(z_loss) + + total_weights = jnp.sum(batch["targets_segmentation"] != 0) + # If gradient accumulation is enabled, we don't need to divide total_loss + # by total_weights and then multiply the computed gradient by total_weights, + # since it's equivalent to computing the gradient from total_loss. + # This simplification reduces the number of operations and makes it easier + # for XLA to move all-reduce out of the gradient accumulation loop when use + # Zero1+GA to reduce communication overhead. + # EPS was used to avoid division by zero, but it's not needed when gradient + # accumulation is enabled since there's no division. + if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation: + loss = total_loss + else: + # When using Tunix gradient accumulation, we revert to standard normalization. + # Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects + # a normalized loss for each step. It handles the accumulation state + # updates and scaling internally. + loss = total_loss / (total_weights + EPS) + + # We keep z-loss normalized by total_weights. 
+ total_z_loss = total_z_loss / (total_weights + EPS) + + # Calculate and Add MTP Loss + mtp_loss = 0.0 + if config.mtp_num_layers > 0 and is_train: + mtp_loss = calculate_mtp_loss(intermediate_outputs, config) + loss += mtp_loss + + # get MoE load balance loss + moe_lb_loss = 0.0 + if config.num_experts > 1: + # Note: the key is affected by the model implementation + possible_keys = [ + ("intermediates", "decoder", "layers", "moe_lb_loss"), + ("intermediates", "decoder", "moe_layers", "moe_lb_loss"), + ] + + total_moe_lb_loss = 0.0 + found_loss = False + for nested_key in possible_keys: + total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0) + if total_moe_lb_loss != 0.0: + found_loss = True + break + + if not found_loss: + max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") + + moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss)) + loss += moe_lb_loss + + # get MoE routed bias term updates + moe_bias_updates = None + if config.routed_bias and config.routed_bias_update_rate > 0.0: + nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") + moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) + + # Add the model's primary output to the intermediates dict so it can be used + # by the acceptance rate calculation in eval_step. 
+ intermediate_outputs["logits"] = logits aux = { + "intermediate_outputs": intermediate_outputs, "total_loss": total_loss, "z_loss": total_z_loss, "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "moe_bias_updates": moe_bias_updates, + "mtp_loss": mtp_loss, } return loss, aux - -def eval_loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: jax.Array): - """Evaluation variant of loss_fn (no dropout, full batch size).""" - batch = {k: v[: config.micro_batch_size_to_eval_on, :] for k, v in data.items()} - - logits = model( - decoder_input_tokens=batch["inputs"], - decoder_positions=batch["inputs_position"], - decoder_segment_ids=batch["inputs_segmentation"], - enable_dropout=False, - ) - - one_hot_targets = jax.nn.one_hot(batch["targets"], config.vocab_size) - xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - + # Zero out padding positions target_mask = batch["targets_segmentation"] != 0 xent = xent * target_mask z_loss = z_loss * target_mask @@ -226,7 +287,7 @@ def train_step( # nnx.value_and_grad differentiates only through nnx.Param variables, # keeping non-differentiable state (RNGs, cache, etc.) frozen. grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True) - (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng) + (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True) # Cast gradients to configured dtype before clipping / accumulation raw_grads = jax.tree.map( @@ -279,16 +340,31 @@ def eval_step( metrics: Dict of scalar evaluation metrics. 
""" model: nnx.Module = nnx.merge(model_graphdef, model_state) - loss, aux = eval_loss_fn(model, config, data, dropout_rng) + loss, aux = loss_fn(model, config, data, dropout_rng, is_train=False) + + mtp_acceptance_rate = 0.0 + if config.mtp_eval_target_module > 0: + mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config) + total_loss = aux["total_loss"] + z_loss = aux["z_loss"] + total_weights = aux["total_weights"] + moe_lb_loss = aux["moe_lb_loss"] + mtp_loss = aux["mtp_loss"] metrics = { "scalar": { "evaluation/loss": loss, - "evaluation/z_loss": aux["z_loss"], - "evaluation/total_loss": aux["total_loss"], - "evaluation/total_weights": aux["total_weights"], - } + "evaluation/z_loss": z_loss, + "evaluation/total_loss": total_loss, + "evaluation/total_weights": total_weights, + "evaluation/moe_lb_loss": moe_lb_loss, + "evaluation/mtp_loss": mtp_loss, + "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, + }, } + # if config.use_dpo: + # metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] + return metrics From 9f9629c9a9ed225b989d1ab161e77fb2e0b95c65 Mon Sep 17 00:00:00 2001 From: Charles Li Date: Thu, 12 Mar 2026 22:25:03 +0000 Subject: [PATCH 3/5] Support gradient_accumulation and align to latest train.py --- src/maxtext/trainers/pre_train/nnx_train.py | 183 ++++++++++++++++---- src/maxtext/utils/gradient_accumulation.py | 122 +++++++++++++ 2 files changed, 272 insertions(+), 33 deletions(-) diff --git a/src/maxtext/trainers/pre_train/nnx_train.py b/src/maxtext/trainers/pre_train/nnx_train.py index dc545ce218..0f3d73af61 100644 --- a/src/maxtext/trainers/pre_train/nnx_train.py +++ b/src/maxtext/trainers/pre_train/nnx_train.py @@ -93,13 +93,14 @@ maybe_record_goodput, record_goodput, ) -from maxtext.common.metric_logger import MetricLogger +from maxtext.common.metric_logger import MetricLogger, record_activation_metrics from maxtext.configs import pyconfig from 
maxtext.input_pipeline.input_pipeline_interface import create_data_iterator from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss from maxtext.optimizers import optimizers from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding from maxtext.utils.globals import EPS +from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad from maxtext.utils.rampup_batch import create_rampup_manager _diag_modules = _cloud_diag() @@ -127,7 +128,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: Returns: (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics. """ - rng1, aqt_rng = jax.random.split(dropout_rng) + # rng1, aqt_rng = jax.random.split(dropout_rng) # Trim to micro-batch size (handles per_device_batch_size < 1 cases) # decimate proportion of data when per_device_batch_size<1 @@ -188,6 +189,24 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: mtp_loss = calculate_mtp_loss(intermediate_outputs, config) loss += mtp_loss + # get indexer loss + indexer_loss = 0.0 + if config.use_sparse_indexer and config.indexer_loss_scaling_factor > 0.0: + indexer_losses = [] + # Extract 'indexer_loss' from model intermediates. + # We check for paths ending in ('self_attention', 'indexer_loss'). + # This handles varying paths caused by different layer names. 
+ for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs): + path_keys = tuple(k.key for k in path if hasattr(k, "key")) + if path_keys[-2:] == ("self_attention", "indexer_loss"): + indexer_losses.append(jnp.ravel(val)) + + if indexer_losses: + indexer_loss = jnp.mean(jnp.concatenate(indexer_losses)) + loss += indexer_loss + else: + max_logging.debug("No indexer loss found.") + # get MoE load balance loss moe_lb_loss = 0.0 if config.num_experts > 1: @@ -227,29 +246,12 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng: "z_loss": total_z_loss, "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, + "indexer_loss": indexer_loss, "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, } return loss, aux - # Zero out padding positions - target_mask = batch["targets_segmentation"] != 0 - xent = xent * target_mask - z_loss = z_loss * target_mask - - total_loss = jnp.sum(xent) - total_weights = jnp.sum(target_mask) - total_z_loss = jnp.sum(z_loss) / (total_weights + EPS) - - loss = total_loss / (total_weights + EPS) - - aux = { - "total_loss": total_loss, - "z_loss": total_z_loss, - "total_weights": total_weights, - } - return loss, aux - # --------------------------------------------------------------------------- # Train / eval steps (purely functional, JIT-able) @@ -282,18 +284,52 @@ def train_step( """ model: nnx.Module = nnx.merge(model_graphdef, model_state) optimizer: nnx.Optimizer = nnx.merge(opt_graphdef, opt_state) + if config.use_dpo: + # Need impl on NNX + pass + # state, reference_params = _split_dpo_state(state) + # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + # extra_dpo_args = [reference_params] + # loss_fn = dpo_loss_fn # Compute loss and gradients w.r.t. model parameters. # nnx.value_and_grad differentiates only through nnx.Param variables, # keeping non-differentiable state (RNGs, cache, etc.) frozen. 
- grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True) - (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True) + if config.gradient_accumulation_steps > 1: + loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng) + else: + if config.optimizer_memory_host_offload: + # Need impl on NNX + pass + # if config.use_dpo: + # reference_params = jax.device_put( + # reference_params, + # max_utils.with_memory_kind(reference_params_sharding, "device"), + # ) + # extra_dpo_args = [reference_params] + if config.shard_optimizer_over_data: + # Need impl on NNX + pass + # params = jax.tree.map( + # functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + # params, + # params_shardings, + # ) + grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True) + (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True) # Cast gradients to configured dtype before clipping / accumulation raw_grads = jax.tree.map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, raw_grads, ) + intermediate_outputs = aux["intermediate_outputs"] + total_weights = aux["total_weights"] + moe_lb_loss = aux["moe_lb_loss"] + indexer_loss = aux["indexer_loss"] + z_loss = aux["z_loss"] + moe_bias_updates = aux["moe_bias_updates"] + mtp_loss = aux["mtp_loss"] # Gradient clipping (implemented directly to avoid Linen TrainState dependency) if config.gradient_clipping_threshold > 0: @@ -301,6 +337,32 @@ def train_step( grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None) else: grads = raw_grads + if config.optimizer_memory_host_offload: + # Need impl on NNX + pass + # state = state.replace( + # opt_state=jax.device_put( + # state.opt_state, + # jax.tree_util.tree_map( + # lambda x: x.with_memory_kind(kind="device"), + # state_mesh_shardings.opt_state, + # ), + # ) + # ) + # Move all parameters to device before optimizer update + if 
config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + # Need impl on NNX + # def move(path, value): + # max_logging.log(f"train.py: Moving f{path} to device") + # return value.with_memory_kind(kind="device") + + # state = state.replace( + # params=jax.device_put( + # state.params, + # jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + # ) + # ) # NNX 0.11+: update takes (model, grads) explicitly. optimizer.update(model, grads) @@ -308,15 +370,53 @@ def train_step( new_model_state = nnx.state(model) new_opt_state = nnx.state(optimizer) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + # Need impl on NNX + pass + # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Flax 'sow' returns a tuple, so we take the first element [0]. + # Updates the shape to be aligned with state. 
+ # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + scalar_metrics = { "learning/loss": loss, - "learning/z_loss": aux["z_loss"], - "learning/total_weights": aux["total_weights"], - "learning/grad_norm": max_utils.l2norm_pytree(grads), - "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), - "learning/param_norm": max_utils.l2norm_pytree(nnx.state(model, nnx.Param)), + "learning/z_loss": z_loss, + "learning/moe_lb_loss": moe_lb_loss, + "learning/indexer_loss": indexer_loss, + "learning/mtp_loss": mtp_loss, + "learning/total_weights": total_weights, } - metrics = {"scalar": scalar_metrics, "scalars": {}} + if config.use_qk_clip: + # Apply QK-Clip + # Need impl on NNX + pass + # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + + # Report max_logits metric + # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) + # if global_max_logit is not None: + # scalar_metrics["learning/max_logits"] = global_max_logit + + if not config.optimizer_memory_host_offload: + scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) + scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(nnx.state(model, nnx.Param)) + if config.use_dpo: + scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] + metrics = { + "scalar": scalar_metrics, + "scalars": {}, + } + + if config.record_internal_nn_metrics: + record_activation_metrics(metrics, intermediate_outputs, config) + + if config.use_dpo: + # Need impl on NNX + pass + # new_state = _merge_dpo_state(new_state, reference_params) return (new_model_state, new_opt_state), metrics @@ -350,6 +450,7 @@ def eval_step( z_loss = aux["z_loss"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] + indexer_loss = aux["indexer_loss"] mtp_loss = 
aux["mtp_loss"] metrics = { "scalar": { @@ -358,6 +459,7 @@ "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights, "evaluation/moe_lb_loss": moe_lb_loss, + "evaluation/indexer_loss": indexer_loss, "evaluation/mtp_loss": mtp_loss, "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, @@ -415,8 +517,8 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh): _, opt_state = nnx.split(optimizer) @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings)) - def shard_states(ms, os): - return ms, os + def shard_states(mshard, oshard): + return mshard, oshard with mesh: model_state, opt_state = shard_states(model_state, opt_state) @@ -608,7 +710,9 @@ def train_loop(config, recorder, state=None): shaped_batch = maxtext_utils.get_shaped_batch(config) init_rng = jax.random.PRNGKey(config.init_weights_seed) example_rng = jax.jit(jax.random.fold_in)(init_rng, 0) - if config.compiled_trainstep_file == "": + # Need impl below func on NNX + # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng)) + if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -624,6 +728,7 @@ def train_loop(config, recorder, state=None): _job_completed_gracefully = False try: last_step_completion = datetime.datetime.now() + max_logging.info(f"Entering train loop from start_step={start_step}") for step in np.arange(start_step, config.steps): prof.maybe_activate_profiler(step, opt_state) @@ -631,7 +736,6 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
- with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng) @@ -649,8 +753,10 @@ def train_loop(config, recorder, state=None): and (step + 1) % config.eval_interval == 0 ): assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop eval_data_iterator.reset() metric_logger.reset_eval_metrics() + eval_step_count = 0 for eval_batch in eval_data_iterator: if config.eval_steps > 0 and eval_step_count >= config.eval_steps: @@ -658,6 +764,7 @@ def train_loop(config, recorder, state=None): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(model_state, eval_batch, nextrng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") eval_step_count += 1 metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) @@ -678,6 +785,7 @@ def train_loop(config, recorder, state=None): checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1) ) if checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress checkpoint_manager.wait_until_finished() _job_completed_gracefully = True @@ -727,8 +835,10 @@ def initialize(argv: Sequence[str]): if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): vertex_tensorboard_manager.configure_vertex_tensorboard(config) + # Create the Goodput recorder recorder = create_goodput_recorder(config) + # Stack traces configurations debug_config = debug_configuration.DebugConfig( stack_trace_config=stack_trace_configuration.StackTraceConfig( collect_stack_trace=config.collect_stack_trace, @@ -741,13 +851,20 @@ def initialize(argv: Sequence[str]): def run(config, recorder, diagnostic_config): - """Run 
the NNX training job.""" + """Run the NNX training job. + + In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping. + """ + # Use nullcontext when diagnostics are stubbed or in decoupled mode diagnostics_context = ( contextlib.nullcontext() if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag" else diagnostic.diagnose(diagnostic_config) ) + if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag": + max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.") + with ( diagnostics_context, max_utils.maybe_get_transformer_engine_context(config), diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e4cad14906..5c68cbe27c 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +from flax import nnx from jax.sharding import NamedSharding from maxtext.common.common_types import ShardMode @@ -137,6 +138,127 @@ def reshape_to_microbatch_accumulations(batch_arr): return loss, aux, raw_grads +# --------------------------------------------------------------------------- +# Gradient accumulation helper for NNX +# --------------------------------------------------------------------------- + + +def nnx_gradient_accumulation_loss_and_grad(_loss_fn, model, config, data, dropout_rng): + """ + Calculates gradients using gradient accumulation. + + This function computes the gradient of `_loss_fn` over multiple microbatches + and accumulates them before returning a single, averaged gradient. It uses + `jax.lax.scan` for efficient accumulation on device. + + It also supports a `shard_optimizer_over_data` mode (e.g., ZeRO-1) where + parameters are cast to bf16 and sharded *before* the accumulation loop + to perform the all-gather in lower precision. + + Args: + _loss_fn: The loss function to differentiate. 
Its signature is expected + to be: `(model, config, data, dropout_rng, is_train=True)`. + model: The model module. + config: Model and training configuration object. Must contain + `gradient_accumulation_steps` and `shard_optimizer_over_data`. + data: A PyTree of batched data. The leading dimension is assumed + to be the total batch size (microbatch_size * num_accumulations). + dropout_rng: JAX PRNGKey for dropout. + (Note: unlike the Linen variant, this function takes no extra_dpo_args.) + + Returns: + A tuple containing: + - total_loss (Array): The mean loss, averaged over all microbatches. + - final_aux (PyTree): Auxiliary outputs, summed across microbatches. + - raw_grads (PyTree): The accumulated and averaged gradients. + """ + + # For more efficient DP/ZeRO-1 + GA + # if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: + # ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) + # grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # else: + # ga_params_shardings = grad_shardings = params_shardings + + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + + # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints + # so that all-gather is done once in the lower precision before the gradient accumulation loop + if config.shard_optimizer_over_data: + + def convert_to_bf16(param): + if param.dtype == jnp.float32: + return param.astype(jnp.bfloat16) + return param + + ga_params = jax.tree.map(convert_to_bf16, params) + else: + ga_params = params + + # ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + + def accumulate_gradient(acc_grad_and_loss, data): + ga_params = acc_grad_and_loss["ga_params"] + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry.
+ + # as ga_params will change during train_step, always create a local_model + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, dropout_rng, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) + + acc_grad_and_loss["rest_state"] = next_rest_state + acc_grad_and_loss["loss"] += aux["total_loss"] + acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] + acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"] + acc_grad_and_loss["grad"] = jax.tree.map(lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"]) + acc_grad_and_loss["total_weights"] += aux["total_weights"] + return acc_grad_and_loss, aux + + def reshape_to_microbatch_accumulations(batch_arr): + """Reshape [B, ...] → [num_microbatches, B//num_microbatches, ...].""" + num_microbatches = config.gradient_accumulation_steps + microbatch_shape = (num_microbatches, batch_arr.shape[0] // num_microbatches) + batch_arr.shape[1:] + return jnp.reshape(batch_arr, microbatch_shape) + + # def reshape_to_microbatch_accumulations(batch_arr): + # """Reshape global batch to microbatches, assuming batch axis is leading.""" + # num_microbatches = config.gradient_accumulation_steps + # microbatch_shape = (batch_arr.shape[0] // num_microbatches, num_microbatches) + batch_arr.shape[1:] + # reshaped_batch_arr = jnp.reshape(batch_arr, microbatch_shape) + # return jnp.swapaxes(reshaped_batch_arr, 0, 1) + + data = jax.tree.map(reshape_to_microbatch_accumulations, data) + init_grad = jax.tree.map(jnp.zeros_like, ga_params) + # init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings) + init_grad_and_loss = { + "loss": 0.0, + "grad": init_grad, + "total_weights": 0, + "moe_lb_loss": 0.0, + "mtp_loss": 0.0, + "ga_params": ga_params, + } + init_grad_and_loss["rest_state"] = rest + + grad_and_loss, aux = jax.lax.scan( + accumulate_gradient, init_grad_and_loss, data, 
length=config.gradient_accumulation_steps + ) + loss = ( + grad_and_loss["loss"] / grad_and_loss["total_weights"] + + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps + + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps + ) + raw_grads = grad_and_loss["grad"] + raw_grads = jax.tree.map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) + aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + + nnx.update(model, grad_and_loss["rest_state"]) + + return loss, aux, raw_grads + + # GA helper functions def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding: """ From 82365f71f86bbcc479a89316d20886412817732e Mon Sep 17 00:00:00 2001 From: Charles Li Date: Wed, 11 Mar 2026 18:49:34 +0000 Subject: [PATCH 4/5] Support muon --- src/maxtext/trainers/pre_train/nnx_train.py | 6 +++--- src/maxtext/utils/maxtext_utils.py | 22 ++++++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/maxtext/trainers/pre_train/nnx_train.py b/src/maxtext/trainers/pre_train/nnx_train.py index 0f3d73af61..561dc34130 100644 --- a/src/maxtext/trainers/pre_train/nnx_train.py +++ b/src/maxtext/trainers/pre_train/nnx_train.py @@ -464,8 +464,8 @@ def eval_step( "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - # if config.use_dpo: - # metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] + if config.use_dpo: + metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -497,7 +497,7 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh): learning_rate_schedule: Learning-rate schedule function. """ learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) + tx = optimizers.get_optimizer(config, learning_rate_schedule, model) # NNX 0.11+: wrt is mandatory; optimizer does not store a model reference. 
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index dab8103a4f..43190aff88 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,6 +20,7 @@ import os from flax import linen as nn +from flax import nnx from flax.linen import partitioning as nn_partitioning from flax.training import train_state @@ -1030,7 +1031,7 @@ def init_initial_state(model, tx, config, is_training, key): return init_decode_state(model.apply, model_vars) -def get_abstract_param(model, config): +def get_abstract_param(model: nn.Module | nnx.Module, config): """Get abstract model structure (name, shape) without materializing the weights to save memory""" with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules): key = jax.random.PRNGKey(0) @@ -1039,14 +1040,17 @@ def get_abstract_param(model, config): config.model_name, batch_size=config.micro_batch_size_to_train_on ) audio_shape = mm_processor.get_dummy_audio_shape_for_init(config) - abstract_vars = jax.eval_shape( - model.init, - {"params": key, "dropout": key, "aqt": key}, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(input_shape, dtype=jnp.int32), - encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, - encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, - ) + if isinstance(model, nn.Module): + abstract_vars = jax.eval_shape( + model.init, + {"params": key, "dropout": key, "aqt": key}, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32), + encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, + encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, + ) + else: # nnx.Module + _, abstract_vars = nnx.split(nnx.eval_shape(lambda: model)) return abstract_vars From 09d3714da8f62668caab60405be8533f8818af4b Mon Sep 17 00:00:00 2001 From: 
Charles Li Date: Thu, 19 Mar 2026 23:48:57 +0000 Subject: [PATCH 5/5] NNX train_compile --- .../trainers/pre_train/nnx_train_compile.py | 265 +++++++++++++ tests/unit/nnx_train_compile_test.py | 355 ++++++++++++++++++ 2 files changed, 620 insertions(+) create mode 100644 src/maxtext/trainers/pre_train/nnx_train_compile.py create mode 100644 tests/unit/nnx_train_compile_test.py diff --git a/src/maxtext/trainers/pre_train/nnx_train_compile.py b/src/maxtext/trainers/pre_train/nnx_train_compile.py new file mode 100644 index 0000000000..d02a945c90 --- /dev/null +++ b/src/maxtext/trainers/pre_train/nnx_train_compile.py @@ -0,0 +1,265 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Save a Cross Ahead of Time Compiled (XAOT) version of nnx_train.py's train step. + +Mirrors train_compile.py but uses the Flax NNX API throughout, in contrast to +train_compile.py which relies on Linen's TrainState. + +Key differences from train_compile.py +-------------------------------------- +- No Linen TrainState. State lives in two separate pytrees: + model_state – nnx.State for the model parameters + opt_state – nnx.State for the optimizer (optax state + step counter) +- nnx.eval_shape creates abstract shapes without materialising parameters, so the + whole compilation is done without ever touching real hardware memory. 
+- Graphdefs (model_graphdef, opt_graphdef) are baked into the partial and are + Python-static across the JIT boundary; they are therefore not listed in + static_argnums. +- in_shardings / out_shardings follow the NNX train_step signature: + in: (model_state, opt_state, batch, rng) + out: ((model_state, opt_state), metrics) + +Entry point: + python -m maxtext.trainers.pre_train.nnx_train_compile [overrides…] +""" + +import functools +import os +from typing import Callable, Sequence + +import jax +from absl import app +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.configs import pyconfig +from maxtext.optimizers import optimizers +from maxtext.trainers.pre_train import nnx_train +from maxtext.trainers.pre_train.train_compile import get_topology_mesh, jit_and_compile, save_compiled, validate_config +from maxtext.utils import gcs_utils, max_utils, maxtext_utils, model_creation_utils, sharding + + +def create_nnx_rngs( + config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None +) -> nnx.Rngs: + """ + Create NNX Rngs + + Args: + config: the configuration + is_training: if the Rngs are for training + rng_key: the Rng key + + Returns: + The NNX Rngs + """ + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + + if is_training: + return nnx.Rngs( + params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) + ) + return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference + + +# --------------------------------------------------------------------------- +# Shaped inputs (NNX version) +# --------------------------------------------------------------------------- + + +def get_shaped_inputs_nnx(topology_mesh, config): + """Build abstract (shape-only) versions of nnx_train.train_step's inputs. 
+ + Uses nnx.eval_shape to trace through model and optimizer construction so that + no actual parameters are allocated. The returned abstract states have + ShapeDtypeStruct leaves and can be passed directly to jax.jit.lower(). + + Returns: + model_graphdef: Static NNX graph definition for the model. + opt_graphdef: Static NNX graph definition for the optimizer. + abstract_model_state: Abstract model parameter pytree. + abstract_opt_state: Abstract optimizer state pytree. + model_shardings: Partition specs mapped to mesh shardings for model_state. + opt_shardings: Partition specs mapped to mesh shardings for opt_state. + data_sharding: Input-batch sharding. + shaped_batch: Shaped batch dict (ShapeDtypeStruct leaves). + shaped_rng: Shaped RNG key. + learning_rate_schedule: LR schedule (baked into the compiled object). + """ + # rng_key = jax.random.PRNGKey(config.init_weights_seed) + # rngs = nnx.Rngs(params=rng_key, dropout=1) + + # ------------------------------------------------------------------ + # 1. Abstract model via nnx.eval_shape — no parameters materialised. 
+  # ------------------------------------------------------------------

  def get_nnx_create_model_fn(config, mesh=None, devices=None) -> Callable:
    """Returns a zero-argument factory suitable for nnx.eval_shape tracing."""

    def _create_model():
      # Under nnx.eval_shape this traces model construction without allocating
      # any real parameter arrays.
      # is_training = model_mode == MODEL_MODE_TRAIN
      # rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key)
      rng_key = jax.random.PRNGKey(config.init_weights_seed)
      rngs = create_nnx_rngs(config, True, rng_key)
      return model_creation_utils.from_config(config, devices, mesh, rngs=rngs, model_mode=MODEL_MODE_TRAIN)

    return _create_model

  with nn.logical_axis_rules(config.logical_axis_rules):
    create_model_fn = get_nnx_create_model_fn(config, topology_mesh)
    abstract_model = nnx.eval_shape(create_model_fn)
    # split separates the static graph structure from the (abstract) state.
    model_graphdef, abstract_model_state = nnx.split(abstract_model)

  # ------------------------------------------------------------------
  # 2. Abstract optimizer via nnx.eval_shape.
  # ------------------------------------------------------------------
  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
  # get_optimizer may inspect the model structure (e.g. for Muon); the abstract
  # model has the same tree structure as the real one, so this is safe.
  tx = optimizers.get_optimizer(config, learning_rate_schedule, abstract_model)

  def _build_optimizer():
    # wrt=nnx.Param restricts optimizer state to trainable parameters.
    return nnx.Optimizer(abstract_model, tx, wrt=nnx.Param)

  abstract_optimizer = nnx.eval_shape(_build_optimizer)
  opt_graphdef, abstract_opt_state = nnx.split(abstract_optimizer)

  # ------------------------------------------------------------------
  # 3. Partition specs → mesh shardings.
+  # ------------------------------------------------------------------
  # NOTE(review): logical_axis_rules is supplied both via the context manager
  # and as an explicit argument to logical_to_mesh_sharding — presumably
  # redundant but harmless; confirm against the flax.linen API.
  with nn.logical_axis_rules(config.logical_axis_rules):
    model_shardings = nn.logical_to_mesh_sharding(
        nnx.get_partition_spec(abstract_model_state), topology_mesh, config.logical_axis_rules
    )
    opt_shardings = nn.logical_to_mesh_sharding(
        nnx.get_partition_spec(abstract_opt_state), topology_mesh, config.logical_axis_rules
    )

  # ------------------------------------------------------------------
  # 4. Shaped batch and RNG.
  # ------------------------------------------------------------------
  data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
  shaped_batch = maxtext_utils.get_shaped_batch(config)

  # Only the shape/dtype of the RNG key is needed for lowering.
  _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
  shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)

  return (
      model_graphdef,
      opt_graphdef,
      abstract_model_state,
      abstract_opt_state,
      model_shardings,
      opt_shardings,
      data_sharding,
      shaped_batch,
      shaped_rng,
      learning_rate_schedule,
  )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main(argv: Sequence[str]) -> None:
  """AOT-compile the NNX train step for the configured target topology."""
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["LIBTPU_INIT_ARGS"] = (
      os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
  )
  print("Starting nnx_train_compile.py...", flush=True)

  # Parse and validate configuration
  config = pyconfig.initialize(argv)
  validate_config(config)

  # Create target mesh
  topology_mesh = get_topology_mesh(config)

  # Print system information after building the compile topology to avoid
  # prematurely initialising the backend.
+  max_utils.print_system_information()

  # Get shaped inputs
  (
      model_graphdef,
      opt_graphdef,
      abstract_model_state,
      abstract_opt_state,
      model_shardings,
      opt_shardings,
      data_sharding,
      shaped_batch,
      shaped_rng,
      _,  # _learning_rate_schedule,
  ) = get_shaped_inputs_nnx(topology_mesh, config)

  # Build the partial that matches what _build_jit_steps produces in nnx_train.
  # graphdefs are static (captured in the Python closure) so they do not appear
  # in static_argnums.
  func_to_compile = functools.partial(nnx_train.train_step, model_graphdef, opt_graphdef, config=config)
  # Give the partial a readable name (functools.partial has no __name__ of its
  # own; attribute assignment stores it on the instance).
  func_to_compile.__name__ = "nnx_train_step"

  shaped_train_args = (abstract_model_state, abstract_opt_state, shaped_batch, shaped_rng)
  shaped_train_kwargs = {}

  # Shardings mirror the train_step signature: (model_state, opt_state, batch, rng).
  in_shard = (model_shardings, opt_shardings, data_sharding, None)
  out_shard = ((model_shardings, opt_shardings), None)
  static_argnums = ()
  # Donate both state pytrees so the compiled step can update them in place.
  donate_argnums = (0, 1)

  # Compile
  print("Jitting and compiling NNX train step...", flush=True)
  compiled = jit_and_compile(
      func_to_compile,
      shaped_train_args,
      shaped_train_kwargs,
      topology_mesh,
      in_shard,
      out_shard,
      static_argnums,
      donate_argnums,
      config,
      nn_partitioning.axis_rules(config.logical_axis_rules),
  )
  print("Jitting and compilation complete!", flush=True)

  # Serialize and save the compiled object
  if config.compiled_trainstep_file != "":
    print("Saving compiled object...")
    save_compiled(compiled, config.compiled_trainstep_file)
    print(f"Successfully saved compiled object as {config.compiled_trainstep_file}")
  print("Finished nnx_train_compile.py successfully!", flush=True)
  print(f"Cost analysis: {compiled.cost_analysis()}")
  print(f"Memory analysis: {compiled.memory_analysis()}")

  # Dump HLO if requested
  if config.dump_hlo:
    gcs_utils.upload_dump(
        config.dump_hlo_local_dir,
        config.dump_hlo_gcs_dir,
        module_name=config.dump_hlo_module_name,
        delete_local_after=config.dump_hlo_delete_local_after,
+        all_host_upload=config.dump_hlo_upload_all,
    )


# Script entry point: absl's app.run handles flag parsing and error reporting.
if __name__ == "__main__":
  app.run(main)
diff --git a/tests/unit/nnx_train_compile_test.py b/tests/unit/nnx_train_compile_test.py
new file mode 100644
index 0000000000..f1b5cdfe8e
--- /dev/null
+++ b/tests/unit/nnx_train_compile_test.py
@@ -0,0 +1,355 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the Ahead-of-Time (AOT) compilation script using the NNX API.
+
+This module contains unit tests for `nnx_train_compile.py`, ensuring that
+various model configurations and parallelism strategies can be successfully
+compiled for different hardware topologies using the Flax NNX API.
+"""

import os.path
import unittest
from tempfile import gettempdir

import pytest

from maxtext.trainers.pre_train.nnx_train_compile import main as nnx_train_compile_main
from tests.utils.test_helpers import get_test_config_path


@pytest.mark.tpu_backend
class NnxTrainCompile(unittest.TestCase):
  """Tests for the Ahead of Time Compilation functionality, nnx_train_compile.py"""

  # Each test invokes the compile entry point with a compile_topology override;
  # AOT compilation runs on CPU against an abstract topology, so no accelerators
  # are needed (hence the cpu_only marks).

  @pytest.mark.cpu_only
  def test_save_compiled_v4(self):
    # Single-slice v4-8 topology with a tiny model.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v4.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v4-8",
            "compile_topology_num_slices=1",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_compiled_v5e(self):
    # Single-slice v5e-16 topology.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v5e.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5e-16",
            "compile_topology_num_slices=1",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_compiled_v5p_two_slices(self):
    # Multi-slice (2x v5p-8) compilation.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_v5p_two_slices.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-8",
            "compile_topology_num_slices=2",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_compiled_v6e(self):
    # Single-slice v6e-16 topology.
    temp_dir = gettempdir()
    compiled_trainstep_file = 
os.path.join(temp_dir, "nnx_test_compiled_v6e.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v6e-16",
            "compile_topology_num_slices=1",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_compiled_tpu7x(self):
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_tpu7x.pickle")
    # NOTE(review): every other test passes "" as argv[0]; this one passes
    # None, relying on pyconfig.initialize ignoring the program-name slot —
    # confirm, and align with "" for consistency.
    nnx_train_compile_main(
        (
            None,
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=tpu7x-16",
            "compile_topology_num_slices=1",
            "ici_fsdp_parallelism=16",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_compiled_tpu7x_two_slices(self):
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_tpu7x_two_slices.pickle")
    # NOTE(review): same None-vs-"" argv[0] inconsistency as above.
    nnx_train_compile_main(
        (
            None,
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=tpu7x-8",
            "compile_topology_num_slices=2",
            "ici_fsdp_parallelism=4",
            "ici_tensor_parallelism=2",
            "dcn_data_parallelism=2",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_sequence_parallelism(self):
    # Sequence parallelism over 16-way ICI with a long target length.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_compiled_sequence_parallelism.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-64",
            "use_iota_embed=true",
            "compile_topology_num_slices=1",
            "ici_sequence_parallelism=16",
            "global_parameter_scale=32",
            "per_device_batch_size=0.0625",
"max_target_length=65536",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_remat_full(self):
    # Full rematerialization policy with fused QKV/MLP on a large topology.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_remat_full.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v6e-256",
            "compile_topology_num_slices=1",
            "per_device_batch_size=1",
            "ici_fsdp_parallelism=16",
            "ici_tensor_parallelism=16",
            "max_target_length=1024",
            "fused_qkv=true",
            "fused_mlp=true",
            "remat_policy=full",
            "use_iota_embed=true",
            "global_parameter_scale=128",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_save_flash(self):
    # Custom remat policy with on-device context.
    compiled_trainstep_file = "/tmp/nnx_test_save_flash"
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-256",
            "compile_topology_num_slices=1",
            "per_device_batch_size=1",
            "remat_policy=custom",
            "context=device",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_gpt3_6b(self):
    # Named-model config (gpt3-6b) rather than explicit dims.
    compiled_trainstep_file = "/tmp/nnx_test_gpt3_6b"
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-256",
            "compile_topology_num_slices=1",
            "model_name=gpt3-6b",
            "per_device_batch_size=1",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_moe_dropping_bf16(self):
    # MoE with dense (non-sparse) matmul and token dropping (capacity_factor=1).
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_dropping_bf16.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-64",
            "use_iota_embed=true",
            "compile_topology_num_slices=1",
            "model_name=mixtral-8x7b",
            "sparse_matmul=False",
            "capacity_factor=1",
"per_device_batch_size=4",
            "max_target_length=1024",
            "attention=flash",
            "dtype=bfloat16",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_moe_megablox_bf16(self):
    # MoE with megablox sparse matmul kernels.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_megablox_bf16.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v6e-256",
            "use_iota_embed=true",
            "compile_topology_num_slices=1",
            "model_name=mixtral-8x7b",
            "sparse_matmul=True",
            "megablox=True",
            "per_device_batch_size=4",
            "max_target_length=1024",
            "attention=flash",
            "dtype=bfloat16",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_moe_deepseek_scanned_bf16(self):
    # DeepSeek-style MoE with scanned layers, sparse matmul without megablox.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_deepseek_scanned_bf16.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-64",
            "use_iota_embed=true",
            "compile_topology_num_slices=1",
            "model_name=deepseek3-test",
            "sparse_matmul=True",
            "megablox=False",
            "per_device_batch_size=2",
            "max_target_length=1024",
            "attention=flash",
            "dtype=bfloat16",
            "weight_dtype=bfloat16",
            "scan_layers=True",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_moe_megablox_ring_ep_random(self):
    # Ring-of-experts expert parallelism with random routing.
    temp_dir = gettempdir()
    compiled_trainstep_file = os.path.join(temp_dir, "nnx_test_moe_megablox_ring_ep_random.pickle")
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-16",
            "use_iota_embed=true",
            "compile_topology_num_slices=1",
            "model_name=deepseek3-test",
            "sparse_matmul=True",
            "megablox=True",
            "per_device_batch_size=4",
            "max_target_length=128",
"use_ring_of_experts=True",
            "use_random_routing=True",
            "attention=flash",
            "dtype=bfloat16",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )

  @pytest.mark.cpu_only
  def test_pipeline_subset(self):
    # Pipeline parallelism over a subset of layers (56 of 61) across 8 slices,
    # combined with 16-way expert parallelism.
    compiled_trainstep_file = "/tmp/nnx_test_pipeline_subset.pickle"
    nnx_train_compile_main(
        (
            "",
            get_test_config_path(),
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5p-128",
            "compile_topology_num_slices=8",
            "use_iota_embed=true",
            "per_device_batch_size=1",
            "max_target_length=1024",
            "pipeline_parallel_layers=56",
            "base_num_decoder_layers=61",
            "ici_expert_parallelism=16",
            "dcn_pipeline_parallelism=8",
            "enable_nnx=True",
            "pure_nnx_decoder=True",
        )
    )