From 82fc0a31998ad0e074b2b726ab7045710d5f3503 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 7 May 2026 19:26:41 +0000 Subject: [PATCH 1/4] NNX migration prep (4/N): sharding tools + post-training bugfixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part 1 — sharding diagnostics: - maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input) - run_sharding_dump.py: add --pure_nnx flag Part 2 — post-training bugfixes (NNX-side): - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts the individual image/audio/mask fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad inside nnx.jit (Tunix's default trainer), which raises "graph structure of a node added to cached_partial was mutated" — refactor to jax.value_and_grad with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint + dtype-cast compat shims for jax 0.9 / tpu_inference Linen<->NNX checkpoint conversion utility and validation tool moved to a follow-up PR (PR4.5) to keep this change reviewable. --- src/maxtext/optimizers/optimizers.py | 4 +- .../post_train/distillation/train_distill.py | 78 +++++++++-------- .../trainers/post_train/rl/train_rl.py | 57 ++++++++++++- .../trainers/post_train/sft/train_sft.py | 80 +++++++++++++++++- src/maxtext/utils/maxtext_utils.py | 47 +++++++---- .../unit/distillation_scheduling_test.py | 44 ++++++++-- .../post_training/unit/train_distill_test.py | 84 +++++++++++++------ tests/utils/run_sharding_dump.py | 9 +- 8 files changed, 314 insertions(+), 89 deletions(-) diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 2ae7e5f8e5..9992d7674f 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu): else: updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params) - step_size = -1.0 * learning_rate_fn(count) + # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped + # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar). + step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn) # Finally, fold in step size. updates = jax.tree_util.tree_map(lambda x: step_size * x, updates) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 6f1264e69b..27b82b1f6b 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -274,30 +274,45 @@ def wrt_filter(path, x): # Inherits _shard_optimizer from PeftTrainer. def _train_step(self, model, optimizer, inputs): - """Overrides the main JIT block to natively handle ModelBundle module.""" + """Overrides the main JIT block to natively handle ModelBundle module. + Uses jax.value_and_grad with explicit split/merge to avoid nesting + nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign + conflicting outer_index values and raises: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. 
+ """ batch = self.gen_model_input_fn(inputs) + student = model.student_model + teacher = model.teacher_model current_step = model.training_step[...] - def loss_wrapper(student, teacher, batch): - if "teacher_output" in batch: - teacher_output = batch["teacher_output"] - else: - teacher_output = self.strategy.teacher_forward_fn( - model=teacher, - input_tokens=batch["input_tokens"], - positions=batch["positions"], - attention_mask=batch.get("attention_mask"), - decoder_segment_ids=batch.get("decoder_segment_ids"), - decoder_target_tokens=batch.get("targets", None), - decoder_target_mask=batch.get("targets_segmentation", None), - cache=None, - ) + # Run teacher inference outside of value_and_grad. + # The teacher is frozen (stop_gradient), so its output is a constant + # from the perspective of the student gradient computation. + if "teacher_output" in batch: + teacher_output = batch["teacher_output"] + else: + teacher_output = self.strategy.teacher_forward_fn( + model=teacher, + input_tokens=batch["input_tokens"], + positions=batch["positions"], + attention_mask=batch.get("attention_mask"), + decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), + cache=None, + ) + teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) - teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) + # Split student into differentiable params and non-differentiable rest. + # Capture graphdef outside of jax.value_and_grad for stable graph tracking. + student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...) + def loss_wrapper_pure(diff_params, rest): + local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True) student_output = self.strategy.student_forward_fn( - model=student, + model=local_student, input_tokens=batch["input_tokens"], positions=batch["positions"], attention_mask=batch.get("attention_mask"), @@ -306,29 +321,26 @@ def loss_wrapper(student, teacher, batch): decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - # we should apply a mask for labels to disable segment-separator tokens labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) - return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) - - # Because student is the 0th argument, argnums=0 guarantees - # we only compute gradients for the student. - grad_fn = nnx.value_and_grad( - loss_wrapper, - argnums=nnx.DiffState(0, self.wrt_filter), - has_aux=True, - ) + loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) + # Capture updated non-param state (e.g. RNG counters) from local_student. + _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...) + return loss, (aux, new_rest) - out, grads = grad_fn(model.student_model, model.teacher_model, batch) + grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True) + (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) - model.training_step.set_value(current_step + 1) + # Propagate updated non-param state back to student. 
+ nnx.update(student, new_rest) - tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + optimizer.update(student, grads) - optimizer.update(model.student_model, grads) + model.training_step.set_value(current_step + 1) + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) if tunix_expects_grad_norm: - return out[0], out[1], optax.global_norm(grads) - return out[0], out[1] + return loss, aux, optax.global_norm(grads) + return loss, aux def _eval_step(self, model, inputs): """Evaluation only needs the student.""" diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 0af37dc10f..4f46b7920a 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -44,12 +44,14 @@ """ from __future__ import annotations +import contextlib from functools import wraps from typing import Any, Optional, Sequence import datasets import grain import jax +import jax.numpy as jnp import json import logging import os @@ -67,6 +69,48 @@ from tunix.rl.rollout import base_rollout from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner from tunix.sft import metrics_logger, profiler +import tunix.generate.utils as tunix_utils + + +@contextlib.contextmanager +def _tpu_inference_compat_patches(): + """Tactical compat shims for tpu_inference. + + tpu_inference has two call-site assumptions that no longer hold: + 1. jax.lax.with_sharding_constraint: assumes silent reshard on mismatch, + but current jax asserts when all mesh axes are Explicit. Fall back to + jax.sharding.reshard on the AssertionError. + 2. tunix._apply_dtype_cast: tpu_inference JaxEinsum defaults + param_dtype=float32 so its weights initialize as float32, but model + dtype is bfloat16; the cast upgraded synced bfloat16 weights to float32, + which then mismatched in the ragged paged attention kernel. Skip the + bf16->f32 upcast so synced weights stay bfloat16. + + Scoped to rl_train() so the patches don't leak into other importers of this + module. Drop both once tpu_inference is updated upstream. + """ + orig_wsc = jax.lax.with_sharding_constraint + orig_apply_dtype_cast = tunix_utils._apply_dtype_cast # pylint: disable=protected-access + + def _compat_wsc(x, shardings): + try: + return orig_wsc(x, shardings) + except AssertionError: + return jax.sharding.reshard(x, shardings) + + def _no_bf16_to_f32_cast(val, tgt_dtype, src_key): + if hasattr(val, "dtype") and val.dtype == jnp.bfloat16 and tgt_dtype == jnp.float32: + return val + return orig_apply_dtype_cast(val, tgt_dtype, src_key) + + jax.lax.with_sharding_constraint = _compat_wsc + tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access + try: + yield + finally: + jax.lax.with_sharding_constraint = orig_wsc + tunix_utils._apply_dtype_cast = orig_apply_dtype_cast # pylint: disable=protected-access + os.environ["TOKENIZERS_PARALLELISM"] = "0" @@ -418,6 +462,8 @@ def create_rl_components( "hf_overrides": trainer_config.vllm_hf_overrides, "enable_expert_parallel": sampler_config.enable_expert_parallel, "enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts + # Ensures vLLM model initializes with correct dtype (not float32 default) + "dtype": trainer_config.weight_dtype, }, rollout_vllm_sampling_kwargs={ "stop": trainer_config.stop_strings, @@ -539,6 +585,12 @@ def rl_train(argv: Sequence[str], kwargs: dict): trainer_devices: JAX devices for the trainer. 
sampler_devices: JAX devices for the sampler. """ + with _tpu_inference_compat_patches(): + _rl_train_impl(argv, kwargs) + + +def _rl_train_impl(argv: Sequence[str], kwargs: dict): + """rl_train body — kept separate so _tpu_inference_compat_patches wraps it cleanly.""" trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices( argv, kwargs ) @@ -563,7 +615,10 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_train_steps = get_max_train_steps(trainer_config) # Create model tokenizer - model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + model_tokenizer = AutoTokenizer.from_pretrained( + trainer_config.tokenizer_path, + token=trainer_config.hf_access_token or None, + ) train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 3674ab70ff..75c7989d9f 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -35,7 +35,8 @@ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ -from typing import Sequence +import inspect +from typing import Any, Sequence from absl import app import os @@ -43,6 +44,7 @@ import optax import pathwaysutils +from flax import nnx from flax.linen import partitioning as nn_partitioning from orbax import checkpoint as ocp @@ -69,6 +71,78 @@ from maxtext.utils import model_creation_utils +class MaxTextPeftTrainer(peft_trainer.PeftTrainer): + """MaxText-specific PeftTrainer that avoids nested NNX transformations. + + Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside + nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index + values to graph nodes, resulting in: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + + This subclass overrides create_train_step_fn to use jax.value_and_grad + with an explicit split/merge pattern (matching MaxText's pre-training NNX + train_step), which avoids the nested NNX transformation issue entirely. + """ + + def create_train_step_fn(self): + """Creates a train step using jax.value_and_grad with explicit NNX split/merge.""" + loss_fn_ref = self.loss_fn + has_aux = self._has_aux + gen_fn = self.gen_model_input_fn + is_lora_enabled = self._lora_enabled + wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param + + # Detect whether Tunix's train() expects (loss, aux, grad_norm) or just + # (loss, aux) by inspecting the source of PeftTrainer._train_step. + tunix_expects_grad_norm = False + try: + source = inspect.getsource(peft_trainer.PeftTrainer._train_step) # pylint: disable=protected-access + tunix_expects_grad_norm = "grad_norm" in source + except (TypeError, OSError): + pass + + # Capture the graphdef once outside of JIT so that split/merge inside + # jax.value_and_grad can use a stable (non-traced) structural descriptor. + graphdef, _, _ = nnx.split(self.model, wrt, ...) + + def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any): + inputs = gen_fn(inputs) + + # Split model into differentiable params and non-differentiable rest. + # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX + # transforms inside nnx.jit, which would corrupt outer_index tracking. + _, diff_params, rest = nnx.split(model, wrt, ...) 
+ + def loss_wrapper(diff_params, rest, **inputs_kw): + local_model = nnx.merge(graphdef, diff_params, rest, copy=True) + out = loss_fn_ref(local_model, **inputs_kw) + # Capture updated non-param state (e.g. RNG counters) from local_model. + _, _, new_rest = nnx.split(local_model, wrt, ...) + if has_aux: + loss, aux = out + return loss, (aux, new_rest) + else: + return out, (None, new_rest) + + grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True) + (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs) + + # Propagate updated non-param state (RNG counters, etc.) back to model. + nnx.update(model, new_rest) + + # Apply optimizer update. grads has the same nnx.State(wrt) structure + # as diff_params, which is compatible with optimizer.update. + optimizer.update(model, grads) + + aux_out = aux if has_aux else None + if tunix_expects_grad_norm: + return out_val, aux_out, optax.global_norm(grads) + return out_val, aux_out + + return train_step + + def get_tunix_config(mt_config): """Gets the Tunix training configurations from the MaxText config. @@ -110,6 +184,7 @@ def get_tunix_config(mt_config): checkpointing_options=checkpointing_options, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, + data_sharding_axis=tuple(mt_config.data_sharding), ) @@ -176,10 +251,9 @@ def setup_trainer_state(mt_config, goodput_recorder=None): # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) if mt_config.lora.lora_restore_path: trainer = lora_utils.restore_lora_from_path(trainer, mt_config) - trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index f536778126..0f07f5c14d 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1910,26 +1910,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. 
""" - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not isinstance(params, nnx.State): + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) - - message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - max_logging.info(message) + if logical_annotations is not None: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( + leaves_params, leaves_sharding, leaves_logical + ): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) + + message = ( + f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + ) + max_logging.info(message) + else: + for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + + message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) diff --git a/tests/post_training/unit/distillation_scheduling_test.py b/tests/post_training/unit/distillation_scheduling_test.py index 21e22839b4..24b9b6d721 100644 --- a/tests/post_training/unit/distillation_scheduling_test.py +++ b/tests/post_training/unit/distillation_scheduling_test.py @@ -412,9 +412,15 @@ def __call__(self, x): self.assertEqual(int(bundle.training_step[...]), 2) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_increments_and_passes_step( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + 
): """_train_step passes pre-increment step to compute_loss and increments after.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() @@ -442,37 +448,54 @@ def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_g # Simulate resume from step 5 model_bundle.training_step.set_value(jnp.array(5, dtype=jnp.int32)) - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # nnx.split returns (graphdef, diff_params, rest); loss_wrapper_pure takes (diff_params, rest). + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # grad_fn returns ((loss, (aux, new_rest)), grads) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) trainer._train_step(model_bundle, optimizer, mock.Mock()) # Step should have incremented to 6 self.assertEqual(int(model_bundle.training_step[...]), 6) - # Trigger loss_wrapper to verify step=5 was passed to compute_loss + # Trigger loss_wrapper_pure to verify step=5 was passed to compute_loss. + # Signature is (diff_params, rest). loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + loss_wrapper(mock_diff_params, mock_rest) call_kwargs = trainer.strategy.compute_loss.call_args self.assertIn("step", call_kwargs.kwargs) self.assertEqual(int(call_kwargs.kwargs["step"]), 5) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_consecutive_train_steps_increment( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """training_step increments 0→1→2→3 across consecutive _train_step calls.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + # Use a real DistillationForwardOutput so jax.tree.map(stop_gradient, ...) works. 
+ fake_teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.zeros((1, 2, 4)), out_projection_activations=None + ) mock_batch = { "input_tokens": mock.Mock(), "positions": mock.Mock(), "targets": mock.Mock(), - "teacher_output": mock.Mock(), + "teacher_output": fake_teacher_output, } trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) @@ -480,7 +503,10 @@ def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_globa model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer = mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index 0c9204877c..6d0d2c1afa 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -162,9 +162,12 @@ def test_prepare_inputs_logic(self): @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_skips_teacher_forward_when_output_present( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is skipped when model_output is already in the batch.""" # 1. Initialize Trainer @@ -189,21 +192,28 @@ def test_train_step_skips_teacher_forward_when_output_present( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. 
Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.teacher_forward_fn.assert_not_called() trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -215,9 +225,12 @@ def test_train_step_skips_teacher_forward_when_output_present( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_calls_teacher_forward_when_output_missing( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is called when model_output is missing from the batch.""" # 1. Initialize Trainer @@ -242,19 +255,27 @@ def test_train_step_calls_teacher_forward_when_output_missing( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_gn = mock.Mock() mock_global_norm.return_value = mock_gn + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure train_step_out = trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. 
Assertions + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], @@ -266,8 +287,9 @@ def test_train_step_calls_teacher_forward_when_output_missing( decoder_target_mask=None, ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -291,8 +313,13 @@ def test_train_step_calls_teacher_forward_when_output_missing( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_passes_targets_segmentation( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm + ): """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -317,22 +344,30 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. 
Assertions trainer.strategy.create_labels.assert_called_once_with( mock_batch["targets"], targets_segmentation=mock_targets_segmentation ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -341,6 +376,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ decoder_target_mask=mock_targets_segmentation, cache=None, ) + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 7d3156fe00..62c71a9b5b 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -59,9 +59,12 @@ flags.DEFINE_string("topology", None, "Specific topology to dump.") flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") flags.DEFINE_string("custom_mesh_and_rule", None, "Specific custom_mesh_and_rule to dump.") +flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.") -def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple) -> None: +def run_single_dump( + model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple, pure_nnx: bool = False +) -> None: """Generate sharding json file for one specific model, topology, slice and rule.""" args = [ "python3", @@ -79,6 +82,8 @@ def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_ args.append(f"custom_mesh_and_rule={custom_mesh_and_rule}") if overrides: args.extend(overrides) + if pure_nnx: + args.append("pure_nnx=true") subprocess.run(args, check=True) @@ -117,7 +122,7 @@ def main(argv: Sequence[str]) -> None: print(" -> Sharding files already exist. Regenerating to overwrite.") try: - run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides) + run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides, pure_nnx=FLAGS.pure_nnx) except subprocess.CalledProcessError: print(f"!!! FAILED: {model_name} {topology} {num_slice} {custom_mesh_and_rule} overrides={overrides}") From 1fcabf7a126936cdb149690ee92b27d50771b9de Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 7 May 2026 21:31:28 +0000 Subject: [PATCH 2/4] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape both directions; round-trips preserve byte values. Top-level key mapping: - Linen params/params/ <-> NNX model/ (double-nesting, {value:} wrappers). - Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu). - Linen step <-> NNX optimizer/step. Layer structure: - scan_layers=True (default): stack layers_N -> layers tensor. - scan_layers=False: rename layers_N -> integer-keyed layers/{N}. NNX->Linen direction auto-detects which layer layout the source uses. --direction=auto picks Linen vs NNX from top-level keys. Pure utility addition. No production-code dependencies; PR5+ do not depend on this branch. Comparison utility split into PR4.6. 
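A minimal usage sketch of the mapping (convert_linen_to_nnx, its module path, and the key layout are as defined in linen_nnx_converter.py in this patch; the toy shapes and dict contents are illustrative only):

  import numpy as np
  from maxtext.checkpoint_conversion.linen_nnx_converter import convert_linen_to_nnx

  # Tiny Linen-style tree: double-nested params, per-layer layers_N keys, top-level step.
  linen_state = {
      "step": 10,
      "params": {"params": {"decoder": {
          "layers_0": {"mlp": {"kernel": np.zeros((3, 4), np.float32)}},
          "layers_1": {"mlp": {"kernel": np.ones((3, 4), np.float32)}},
      }}},
  }

  nnx_state = convert_linen_to_nnx(linen_state, scan_layers=True)
  # Layers are stacked and the layer dim moved to axis 1; leaves get {value:} wrappers:
  #   nnx_state["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"].shape == (3, 2, 4)
  # step moves under optimizer:
  #   nnx_state["optimizer"]["step"] == 10

The same tree converts back through convert_nnx_to_linen to the original byte values, which is what the round-trip unit tests below exercise.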
--- .../linen_nnx_converter.py | 581 ++++++++++++ tests/unit/linen_nnx_converter_test.py | 869 ++++++++++++++++++ 2 files changed, 1450 insertions(+) create mode 100644 src/maxtext/checkpoint_conversion/linen_nnx_converter.py create mode 100644 tests/unit/linen_nnx_converter_test.py diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. + +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. 
" f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). 
+ """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if 
isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def _convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). 
+ """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). + """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). 
+ """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + log(f" Loaded keys: {list(state.keys())}") + return state + + +def save_checkpoint(state: dict, output_path: str) -> None: + """Saves checkpoint to local or GCS path.""" + log(f"Saving checkpoint to: {output_path}") + + output_dir = epath.Path(output_path) + output_dir.mkdir(exist_ok=True, parents=True) + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(output_dir, state, force=True) + log(" Checkpoint saved successfully") + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between Linen and NNX checkpoint formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source_path", + type=str, + required=True, + help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).", + ) + parser.add_argument( + "--target_path", + type=str, + required=True, + help="Path to save converted checkpoint.", + ) + parser.add_argument( + "--direction", + type=str, + choices=["auto", "linen_to_nnx", "nnx_to_linen"], + default="auto", + help="Conversion direction. 
'auto' detects from source format.", + ) + parser.add_argument( + "--scan_layers", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For linen_to_nnx only: if True (default), stack per-layer arrays into a " + "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). " + "If False, rename layers_N to integer-keyed layers/N without stacking " + "(for NNX with scan_layers=False)." + ), + ) + + args = parser.parse_args() + + print("=" * 80) + print("Linen <-> NNX Checkpoint Converter") + print("=" * 80) + + start_time = time.time() + + state = load_checkpoint(args.source_path) + + if args.direction == "auto": + source_format = detect_format(state) + target_format = "nnx" if source_format == "linen" else "linen" + log(f"Auto-detected: {source_format} → {target_format}") + else: + source_format = args.direction.split("_to_")[0] + target_format = args.direction.split("_to_")[1] + log(f"Using specified direction: {source_format} → {target_format}") + + log(f"Converting: {source_format} → {target_format}") + if source_format == "linen": + log(f"scan_layers={args.scan_layers}") + + if source_format == "linen" and target_format == "nnx": + converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers) + elif source_format == "nnx" and target_format == "linen": + converted_state = convert_nnx_to_linen(state) + else: + raise ValueError(f"Invalid conversion: {source_format} → {target_format}") + + save_checkpoint(converted_state, args.target_path) + + elapsed = time.time() - start_time + print("\n" + "=" * 80) + print(f"Conversion complete in {elapsed:.2f} seconds") + print(f" Source: {args.source_path}") + print(f" Target: {args.target_path}") + print(f" Direction: {source_format} → {target_format}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/linen_nnx_converter_test.py b/tests/unit/linen_nnx_converter_test.py new file mode 100644 index 0000000000..808990f8cf --- /dev/null +++ b/tests/unit/linen_nnx_converter_test.py @@ -0,0 +1,869 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for linen_nnx_converter utilities.""" + +import unittest +import numpy as np +from unittest.mock import MagicMock, patch + +from maxtext.checkpoint_conversion.linen_nnx_converter import ( + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _add_value_wrappers, + _transpose_layers_axes, + _stack_layers, + convert_linen_to_nnx, + convert_nnx_to_linen, + _convert_opt_state_linen_to_nnx, + _convert_opt_state_nnx_to_linen, + load_checkpoint, + save_checkpoint, + main, +) + + +def _make_array(*shape): + """Helper to create a numpy array with given shape.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + +class TestDetectFormat(unittest.TestCase): + """Tests for the detect_format function.""" + + def test_raises_when_no_params_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_detects_nnx_format_via_model_key(self): + # NNX: top-level "model" key + state = {"model": {"decoder": {"layers": {}}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_format_double_nested(self): + state = {"params": {"params": {"decoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_format_single_nested_with_value_wrappers(self): + # Old NNX format: params/decoder with {value:} wrappers + arr = _make_array(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_encoder(self): + state = {"params": {"params": {"encoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_encoder_with_value_wrappers(self): + arr = _make_array(2, 2) + state = {"params": {"encoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_nnx_via_optimizer_key(self): + arr = _make_array(2, 2) + state = {"params": {"something": arr}, "optimizer": {"step": 0}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_opt_state(self): + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "opt_state": {"params": {"mu": {"decoder": {"kernel": arr}}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_optimizer_over_opt_state(self): + # "optimizer" key takes precedence for NNX detection + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "optimizer": {"step": 0, "opt_state": {}}, + } + self.assertEqual(detect_format(state), "nnx") + + def test_raises_on_undetectable_format(self): + state = {"params": {"some_unknown_key": 42}} + with self.assertRaises(ValueError): + detect_format(state) + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for the _has_value_wrappers helper.""" + + def test_returns_true_for_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"value": arr})) + + def test_returns_true_for_nested_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"mu": {"value": arr}})) + + def test_returns_false_for_plain_array(self): + # A plain array is not a {"value": ...} wrapper dict + self.assertFalse(_has_value_wrappers(_make_array(2, 2))) + + def test_returns_false_for_multi_key_dict(self): + arr = _make_array(2, 2) + self.assertFalse(_has_value_wrappers({"value": arr, "extra": arr})) + + def test_returns_false_for_non_array_value(self): + self.assertFalse(_has_value_wrappers({"value": "string"})) + + +class 
TestStripValueWrappers(unittest.TestCase): + """Tests for the _strip_value_wrappers helper.""" + + def test_strips_single_wrapper(self): + arr = _make_array(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _make_array(2, 2) + wrapped = {"decoder": {"layers": {"kernel": {"value": arr}}}} + stripped = _strip_value_wrappers(wrapped) + np.testing.assert_array_equal(stripped["decoder"]["layers"]["kernel"], arr) + + def test_passes_through_plain_array(self): + arr = _make_array(2, 3) + result = _strip_value_wrappers(arr) + np.testing.assert_array_equal(result, arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _strip_value_wrappers([{"value": arr}]) + result_tuple = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result_list[0], arr) + np.testing.assert_array_equal(result_tuple[0], arr) + + def test_passes_through_non_array_value(self): + # A dict with key "value" but scalar content should not be unwrapped + d = {"value": 42} + result = _strip_value_wrappers(d) + self.assertEqual(result, d) + + +class TestAddValueWrappers(unittest.TestCase): + """Tests for the _add_value_wrappers helper.""" + + def test_wraps_array(self): + arr = _make_array(3, 4) + result = _add_value_wrappers(arr) + self.assertIsInstance(result, dict) + self.assertIn("value", result) + np.testing.assert_array_equal(result["value"], arr) + + def test_wraps_nested_arrays(self): + arr = _make_array(2, 2) + nested = {"decoder": {"layers": {"kernel": arr}}} + wrapped = _add_value_wrappers(nested) + self.assertEqual(set(wrapped["decoder"]["layers"]["kernel"].keys()), {"value"}) + np.testing.assert_array_equal(wrapped["decoder"]["layers"]["kernel"]["value"], arr) + + def test_idempotent_on_already_wrapped(self): + arr = _make_array(2) + already_wrapped = {"value": arr} + result = _add_value_wrappers(already_wrapped) + # Should not double-wrap + self.assertEqual(set(result.keys()), {"value"}) + np.testing.assert_array_equal(result["value"], arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _add_value_wrappers([arr]) + result_tuple = _add_value_wrappers((arr,)) + self.assertEqual(set(result_list[0].keys()), {"value"}) + self.assertEqual(set(result_tuple[0].keys()), {"value"}) + + def test_passes_through_non_array_scalars(self): + result = _add_value_wrappers(42) + self.assertEqual(result, 42) + result_str = _add_value_wrappers("text") + self.assertEqual(result_str, "text") + + +class TestTransposeLayersAxes(unittest.TestCase): + """Tests for the _transpose_layers_axes helper.""" + + def test_noop_when_same_axis(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=0) + np.testing.assert_array_equal(result, arr) + + def test_transposes_axis_0_to_1(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + self.assertEqual(result.shape, (2, 4, 3)) + + def test_transposes_axis_1_to_0(self): + arr = _make_array(2, 4, 3) + result = _transpose_layers_axes(arr, src_axis=1, dst_axis=0) + self.assertEqual(result.shape, (4, 2, 3)) + + def test_transposes_nested_dict(self): + arr = _make_array(4, 2, 3) + tree = {"decoder": {"layers": {"kernel": arr}}} + result = _transpose_layers_axes(tree, src_axis=0, dst_axis=1) + self.assertEqual(result["decoder"]["layers"]["kernel"].shape, (2, 4, 3)) + + def test_passes_through_1d_array(self): + arr = _make_array(5) + result = 
_transpose_layers_axes(arr, src_axis=0, dst_axis=1) + # 1D array has no axis 1, should be returned unchanged + np.testing.assert_array_equal(result, arr) + + def test_handles_list(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes([arr], src_axis=0, dst_axis=1) + self.assertIsInstance(result, list) + self.assertEqual(result[0].shape, (2, 4, 3)) + + def test_handles_tuple(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes((arr,), src_axis=0, dst_axis=1) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0].shape, (2, 4, 3)) + + +class TestStackLayers(unittest.TestCase): + """Tests for the _stack_layers helper.""" + + def test_stacks_individual_layers(self): + arr0 = _make_array(3, 4) + arr1 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "layers_1": {"mlp": {"kernel": arr1}}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + stacked = result["layers"]["mlp"]["kernel"] + self.assertEqual(stacked.shape, (2, 3, 4)) + np.testing.assert_array_equal(stacked[0], arr0) + np.testing.assert_array_equal(stacked[1], arr1) + + def test_noop_when_no_layer_pattern(self): + arr = _make_array(3, 4) + decoder = {"layers": {"mlp": {"kernel": arr}}} + result, was_stacked = _stack_layers(decoder) + self.assertFalse(was_stacked) + self.assertIs(result, decoder) + + def test_preserves_non_layer_keys(self): + norm_weight = _make_array(4) + arr0 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "final_norm": {"scale": norm_weight}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("final_norm", result) + np.testing.assert_array_equal(result["final_norm"]["scale"], norm_weight) + + def test_stacks_three_layers(self): + arrays = [_make_array(2, 2) for _ in range(3)] + decoder = {f"layers_{i}": {"w": arrays[i]} for i in range(3)} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + stacked = result["layers"]["w"] + self.assertEqual(stacked.shape, (3, 2, 2)) + + def test_non_array_non_dict_leaf(self): + # Scalar leaf — stack_arrays returns first element + decoder = {"layers_0": {"count": 1}, "layers_1": {"count": 2}} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + + def test_with_missing_key_in_some_layers(self): + arr = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr, "bias": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, # no "bias" + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("kernel", result["layers"]["mlp"]) + + +class TestConvertLinenToNNX(unittest.TestCase): + """Tests for the convert_linen_to_nnx function.""" + + def _make_linen_state(self, add_opt_state=False): + """Creates a minimal Linen checkpoint structure.""" + arr = _make_array(2, 4, 3) + state = { + "step": 10, + "params": { + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + "decoder_norm": {"scale": _make_array(4)}, + } + } + }, + } + if add_opt_state: + state["opt_state"] = {"params": {"mu": {"decoder": {"layers": {"kernel": arr}}}}} + return state + + def test_converts_step_under_optimizer(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 10) + + def test_step_not_at_top_level(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + 
self.assertNotIn("step", result) + + def test_params_stored_under_model_key(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertIn("model", result) + self.assertNotIn("params", result) + + def test_removes_double_nesting(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # model should have 'decoder' directly, not 'params.decoder' + self.assertIn("decoder", result["model"]) + self.assertNotIn("params", result["model"]) + + def test_adds_value_wrappers(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # Arrays should be wrapped in {"value": array} + kernel = result["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_opt_state_under_optimizer(self): + state = self._make_linen_state(add_opt_state=True) + result = convert_linen_to_nnx(state) + self.assertIn("opt_state", result["optimizer"]) + # Linen opt_state had nested 'params' level; it should be removed + self.assertNotIn("params", result["optimizer"]["opt_state"]) + + def test_no_step_produces_no_optimizer_step(self): + arr = _make_array(2, 4, 3) + state = {"params": {"params": {"decoder": {"layers": {"kernel": arr}}}}} + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + self.assertIn("model", result) + + def test_no_double_nesting_still_converts(self): + # Linen state without double-nesting (unusual but handled) + arr = _make_array(2, 4) + state = {"params": {"decoder": {"layers": {"kernel": arr}}}} + result = convert_linen_to_nnx(state) + self.assertIn("decoder", result["model"]) + + def test_no_params_key_only_step(self): + state = {"step": 3} + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 3) + self.assertNotIn("model", result) + + def test_with_per_layer_params_stacked_and_transposed(self): + # Linen checkpoint with layers_0, layers_1 → stacked + transposed to axis 1 + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"] + # Original (3, 4) stacked → (2, 3, 4), transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestConvertNNXToLinen(unittest.TestCase): + """Tests for the convert_nnx_to_linen function.""" + + def _make_nnx_state(self, add_opt_state=False): + """Creates an NNX checkpoint with 'model' and 'optimizer' keys. + + Uses 'attention' (not 'layers') as the sub-key so _convert_layers_to_linen_format + does not try to unstack the data. 
+ """ + arr = _make_array(2, 4, 3) + state = { + "model": { + "decoder": { + "attention": {"wi": {"kernel": {"value": arr}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + "optimizer": {"step": 5}, + } + if add_opt_state: + state["optimizer"]["opt_state"] = { + "mu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "nu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + } + return state + + def test_converts_step(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + + def test_adds_double_nesting(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertIn("params", result["params"]) + self.assertIn("decoder", result["params"]["params"]) + + def test_strips_value_wrappers(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + kernel = result["params"]["params"]["decoder"]["attention"]["wi"]["kernel"] + self.assertIsInstance(kernel, np.ndarray) + + def test_converts_opt_state(self): + state = self._make_nnx_state(add_opt_state=True) + result = convert_nnx_to_linen(state) + self.assertIn("opt_state", result) + # mu/nu should get a 'params' level added + self.assertIn("params", result["opt_state"]["mu"]) + self.assertIn("params", result["opt_state"]["nu"]) + + def test_backward_compat_params_key(self): + # Old NNX format: "params" instead of "model", top-level "step" + arr = _make_array(2, 4, 3) + state = { + "step": 5, + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + } + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + self.assertIn("decoder", result["params"]["params"]) + + def test_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + +class TestRoundTrip(unittest.TestCase): + """Verifies that linen->nnx->linen round-trip preserves data.""" + + def test_linen_to_nnx_to_linen(self): + # Use "attention" (not "layers") so _convert_layers_to_linen_format + # does not try to unstack the dict as a stacked-layers tensor. 
+ arr = _make_array(2, 4, 3) + linen_state = { + "step": 42, + "params": { + "params": { + "decoder": { + "attention": {"mlp": {"wi": {"kernel": arr}}}, + "norm": {"scale": _make_array(4)}, + } + } + }, + } + nnx_state = convert_linen_to_nnx(linen_state) + recovered_state = convert_nnx_to_linen(nnx_state) + + self.assertEqual(recovered_state["step"], 42) + recovered_kernel = recovered_state["params"]["params"]["decoder"]["attention"]["mlp"]["wi"]["kernel"] + np.testing.assert_array_equal(recovered_kernel, arr) + + def test_nnx_to_linen_to_nnx(self): + arr = _make_array(2, 4, 3) + nnx_state = { + "model": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + } + }, + "optimizer": {"step": 7}, + } + linen_state = convert_nnx_to_linen(nnx_state) + recovered_state = convert_linen_to_nnx(linen_state) + + self.assertEqual(recovered_state["optimizer"]["step"], 7) + recovered_kernel = recovered_state["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIn("value", recovered_kernel) + np.testing.assert_array_equal(recovered_kernel["value"], arr) + + +class TestConvertOptState(unittest.TestCase): + """Tests for the _convert_opt_state_linen_to_nnx and _convert_opt_state_nnx_to_linen helpers.""" + + def test_linen_to_nnx_removes_params_level(self): + arr = _make_array(3, 4) + opt_state = {"mu": {"params": {"decoder": {"kernel": arr}}}} + result = _convert_opt_state_linen_to_nnx(opt_state) + # 'params' key removed; decoder promoted + self.assertNotIn("params", result["mu"]) + self.assertIn("decoder", result["mu"]) + # Arrays are plain (no value wrappers in NNX opt_state) + np.testing.assert_array_equal(result["mu"]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": arr}}, {"decoder": {"kernel": arr}}] + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": arr}},) + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_non_array_non_dict(self): + # Scalars should be passed through unchanged + result = _convert_opt_state_linen_to_nnx(42) + self.assertEqual(result, 42) + + def test_linen_to_nnx_params_key_with_non_dict_value(self): + # When k == "params" but converted value is not a dict, store it as-is + opt_state = {"params": 99} + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIn("params", result) + self.assertEqual(result["params"], 99) + + def test_nnx_to_linen_adds_params_level_and_strips(self): + arr = _make_array(3, 4) + opt_state = { + "mu": {"decoder": {"kernel": {"value": arr}}}, + "nu": {"decoder": {"kernel": {"value": arr}}}, + } + result = _convert_opt_state_nnx_to_linen(opt_state) + # mu/nu should have 'params' nested inside + self.assertIn("params", result["mu"]) + self.assertIn("params", result["nu"]) + # Arrays unwrapped + kernel = result["mu"]["params"]["decoder"]["kernel"] + np.testing.assert_array_equal(kernel, arr) + + def test_nnx_to_linen_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": {"value": arr}}}] + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, list) + 
np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": {"value": arr}}},) + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_passes_through_scalars(self): + result = _convert_opt_state_nnx_to_linen("scalar_string") + self.assertEqual(result, "scalar_string") + + def test_nnx_to_linen_value_wrapper_with_non_array_inner(self): + # {"value": scalar} should NOT be unwrapped (only arrays get unwrapped) + d = {"value": 42} + result = _convert_opt_state_nnx_to_linen(d) + self.assertIn("value", result) + self.assertEqual(result["value"], 42) + + +class TestConvertLinenToNNXEncoder(unittest.TestCase): + """Tests encoder path in convert_linen_to_nnx.""" + + def test_converts_encoder_params(self): + arr = _make_array(2, 4, 3) + state = { + "params": { + "params": { + "encoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + } + } + } + } + result = convert_linen_to_nnx(state) + self.assertIn("encoder", result["model"]) + kernel = result["model"]["encoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_encoder_with_per_layer_stacking(self): + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "encoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["encoder"]["layers"]["mlp"]["kernel"]["value"] + # Stacked at axis 0 → (2, 3, 4), then transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestAdditionalEdgeCases(unittest.TestCase): + """Covers remaining edge cases.""" + + def test_detect_format_params_has_params_but_no_decoder_encoder(self): + # params["params"] exists but inner has no decoder/encoder -> falls through + # no optimizer/opt_state -> should raise + state = {"params": {"params": {"some_other_key": {}}}} + with self.assertRaises(ValueError): + detect_format(state) + + def test_detect_format_opt_state_returns_linen(self): + # Any state with "opt_state" (but no "model"/"optimizer") detects as linen + arr = _make_array(2) + state = { + "params": {"something": arr}, + "opt_state": {"mu": {"decoder": {"kernel": arr}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_add_value_wrappers_value_key_with_non_array(self): + # {"value": "text"} is not a wrapper (inner is not an array), recurse normally + d = {"value": "not_an_array"} + result = _add_value_wrappers(d) + self.assertEqual(result, {"value": "not_an_array"}) + + def test_convert_nnx_to_linen_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_already_has_params_nesting(self): + arr = _make_array(2, 4) + state = {"params": {"params": {"decoder": {"layers": {"kernel": {"value": arr}}}}}} + result = convert_nnx_to_linen(state) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_no_params_key(self): + state = {"optimizer": {"step": 8}} + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 8) + self.assertNotIn("params", result) + + +class TestLoadCheckpoint(unittest.TestCase): + """Tests for 
load_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_calls_checkpointer_and_returns_state(self, mock_epath, mock_ocp): + arr = _make_array(2, 2) + expected_state = {"params": arr, "step": 0} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {"params": arr} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.ArrayRestoreArgs.return_value = MagicMock() + + result = load_checkpoint("/tmp/test_ckpt") + + mock_epath.Path.assert_called_once_with("/tmp/test_ckpt") + mock_ocp.Checkpointer.assert_called_once() + mock_ckptr.metadata.assert_called_once_with(mock_path) + mock_ckptr.restore.assert_called_once() + self.assertEqual(result, expected_state) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_with_empty_tree_metadata(self, mock_epath, mock_ocp): + expected_state = {"step": 5} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + + result = load_checkpoint("/tmp/empty_ckpt") + + self.assertEqual(result["step"], 5) + + +class TestSaveCheckpoint(unittest.TestCase): + """Tests for save_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_creates_dir_and_saves(self, mock_epath, mock_ocp): + state = {"params": _make_array(2, 2), "step": 1} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/output") + + mock_epath.Path.assert_called_once_with("/tmp/output") + mock_path.mkdir.assert_called_once_with(exist_ok=True, parents=True) + mock_ocp.PyTreeCheckpointer.assert_called_once() + mock_ckptr.save.assert_called_once_with(mock_path, state, force=True) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_passes_state_unchanged(self, mock_epath, mock_ocp): + state = {"step": 99, "params": {"decoder": {}}} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/out2") + + call_args = mock_ckptr.save.call_args + self.assertIs(call_args[0][1], state) + + +class TestMain(unittest.TestCase): + """Tests for the main() CLI entry point.""" + + def _run_main(self, argv): + with patch("sys.argv", ["prog"] + argv): + main() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_linen_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 1, + "params": {"params": 
{"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # NNX format: decoder at top level of model + self.assertIn("decoder", saved_state["model"]) + self.assertEqual(mock_save.call_args[0][1], "/dst") + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_nnx_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 2}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=nnx_to_linen"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Linen format: double nesting + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_linen_converts_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 3, + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected linen → NNX format: model key + self.assertIn("decoder", saved_state["model"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_nnx_converts_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 4}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected nnx → Linen format + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_default_direction_is_auto(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + # No --direction arg -> defaults to "auto" + self._run_main(["--source_path=/src", "--target_path=/dst"]) + mock_save.assert_called_once() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_scan_layers_false(self, mock_load, mock_save): + arr = _make_array(3, 4) + mock_load.return_value = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx", "--no-scan_layers"]) + saved_state = mock_save.call_args[0][0] + # With scan_layers=False: integer-keyed layers/N + layers = saved_state["model"]["decoder"]["layers"] + self.assertIsInstance(layers, dict) 
+ self.assertTrue(all(k.isdigit() for k in layers.keys())) + + +if __name__ == "__main__": + unittest.main() From 18e7be094276a4855866d8657ef979c2863e073b Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 28 Apr 2026 21:17:04 +0000 Subject: [PATCH 3/4] NNX: correctness fixes, enable feature paths, and vocab tiling on NNX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes (run as no-op while pure_nnx=False stays default): - nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing; call from ToLinen after nnx.update to fix "Cannot extract graph node from different trace level" when grad tracers leak into Variable._trace_state. - gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout = linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError. - normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity. - attentions.py / qwen3.py: callsites eps= -> epsilon=. - moe.py: per_expert_scale block moved into the unfused-kernel else branch (was scaling wo even when fused_kernel was active). - models.py: build MTP block as MultiTokenPredictionBlock(...) directly (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole to NNXDecoder instead of unpacking 5 fields. - gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan carry); use nnx.merge(..., copy=True) to avoid Variable reuse. - diloco.py: NNX-aware state handling — state.params -> state.model.filter (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params helper for jax.lax.cond pytree-structure parity. - train_compile.py: new _collect_nnx_activation_shardings helper (forward pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only traces __init__); NNX path now passes 2-arg shaped_train_args (no rng); diloco path patched to handle the 2-vs-3 length difference. - muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape. - nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu) retain tracers in Linen scope across re-traces. Skip jax.checkpoint and use a Python for-loop instead of jax.lax.scan when quantization is FP8. Makes FP8 quantization usable on the NNX path. - train.py (pre-train train_step): return nnx.state(new_state, nnx.Not (nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for QK-Clip) don't break leaf-count parity with state_mesh_shardings. - llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer _norm RMSNorm (was missing on this norm only). - base.yml: add 4 pipeline-related logical_axis_rules — layers_outside _pipeline, layers_per_stage, num_activations, circular_repeats. Additive, no-op without use_nnx_pipeline=True. NNX feature enablements (clear all 17 "Pure NNX support has not been implemented yet" NotImplementedError sites by routing Linen-coupled utilities to the Linen path; their on-disk format is Linen): - layerwise_quantization.py (2 sites): operates on Linen-format checkpoints via DeepSeek*ToLinen layers. - lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen tree shape; LoRA adapters on disk are Linen. - standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses state.opt_state[0]._replace(mu=..., nu=...) — Linen-only. 
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and _save_decode_checkpoint use state.params["params"]["decoder"] — Linen. - convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree paths (.params['params'], .opt_state.mu['params']). - maxengine.py (3 sites): inference engine uses state.params and serves Linen-format inference checkpoints. - grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route to Linen with a clear log warning since NNX-format checkpoints will fail at restore time. Vocab tiling on NNX (real implementation, not just routing): - models.py: add Transformer.logits_from_hidden_states on the NNX Transformer class — wraps NNXDecoder.apply_output_head with the token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states. - vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per chunk. The NNX model carries its parameters internally so no explicit FSDP gather is needed (unlike the Linen gathered_params pattern). MVP uses default autograd; custom_vjp memory-savings optimization is a follow-up if backward memory becomes a concern. - train.py (NNX loss_fn): replace the NotImplementedError with the call to vocab_tiling_nnx_loss using hidden_states from intermediates. - pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1 and enable_nnx validation guards (no longer needed). DPO + NNX retained as NotImplementedError but with a much more informative message (points users at pure_nnx=False workaround). Full implementation is deferred — needs a new TrainState shape carrying both policy and reference NNX models plus an NNX dpo_loss_fn. Stats: 26 source files modified, +406 / -171 lines. Linen invariant verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False; Linen-path UTs unaffected (3 pre-existing failures on the parent branch remain unchanged — sharding_compare_test::deepseek2-16b, optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two _slices x2). All "Pure NNX support has not been implemented yet" NotImplementedError sites cleared (was 17, now 0). 
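For reviewers, a rough sketch of what the new NNX vocab-tiling loss boils down to.
This is illustrative only, not the code in this diff: the helper name
vocab_tiling_nnx_loss_sketch, the token-axis tiling (the real implementation may
slice differently), the deterministic=True flag, and the omission of target
masking and z-loss are all simplifications made here. It is just the
scan-and-accumulate pattern around logits_from_hidden_states described above.

    # Illustrative sketch; assumes hidden_states [B, S, E] and integer targets [B, S].
    import jax
    import jax.numpy as jnp
    import optax

    def vocab_tiling_nnx_loss_sketch(model, hidden_states, targets, num_tiles, model_mode):
      b, s, e = hidden_states.shape
      tokens = b * s  # the config validation guarantees tokens % num_tiles == 0
      tile = tokens // num_tiles
      h = hidden_states.reshape(num_tiles, 1, tile, e)  # [tiles, batch=1, tile_len, emb]
      t = targets.reshape(num_tiles, 1, tile)

      def body(xent_sum, xs):
        h_tile, t_tile = xs
        # Full-vocab logits are materialized for only one tile of tokens at a time;
        # the NNX model carries its params internally, so no explicit gather is needed.
        logits = model.logits_from_hidden_states(h_tile, deterministic=True, model_mode=model_mode)
        xent = optax.softmax_cross_entropy_with_integer_labels(logits, t_tile).sum()
        return xent_sum + xent, None

      xent_sum, _ = jax.lax.scan(body, jnp.zeros(()), (h, t))
      return xent_sum

With per_device_batch_size=4, max_target_length=2048 and num_vocab_tiling=8, each
scan step holds full-vocab logits for 1024 tokens instead of all 8192, which is
the memory saving vocab tiling is after.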
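Also for reviewers, the two state layouts that the routing decisions above keep
switching between, roughly. This is pieced together from the diffs and converter
tests in this series, not a formal schema; real checkpoints have far more leaves,
and NNX checkpoints additionally wrap leaves as {"value": ...} on disk.

    # Illustrative shapes only.
    linen_state = {
        "step": 0,
        "params": {"params": {"decoder": {}}},      # double-nested param tree
        "opt_state": {},                             # mu/nu moment trees live here
    }
    nnx_state = {
        "model": {"decoder": {}},                    # nnx.Param variables live under "model"
        "optimizer": {"step": 0, "opt_state": {}},   # step counter sits under "optimizer"
    }

The Linen-routed utilities listed above (maxengine, grpo_trainer, the conversion
scripts, LoRA, layerwise quantization) all index into the first shape; diloco.py
and the pre-train train_step now handle both.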
--- .../convert_gpt3_ckpt_from_paxml.py | 15 +-- src/maxtext/configs/base.yml | 7 ++ src/maxtext/configs/pyconfig_deprecated.py | 3 +- src/maxtext/configs/types.py | 3 +- src/maxtext/experimental/rl/grpo_trainer.py | 37 +++--- src/maxtext/inference/maxengine/maxengine.py | 22 ++-- src/maxtext/layers/attentions.py | 4 +- src/maxtext/layers/moe.py | 6 +- src/maxtext/layers/nnx_decoders.py | 30 ++++- src/maxtext/layers/nnx_wrappers.py | 35 ++++++ src/maxtext/layers/normalizations.py | 14 ++- src/maxtext/models/gpt_oss.py | 5 +- src/maxtext/models/llama2.py | 1 + src/maxtext/models/models.py | 13 +++ src/maxtext/models/olmo3.py | 4 +- src/maxtext/models/qwen3.py | 4 +- src/maxtext/models/qwen3_5.py | 4 +- src/maxtext/trainers/diloco/diloco.py | 59 ++++++++-- src/maxtext/trainers/pre_train/train.py | 22 +++- .../trainers/pre_train/train_compile.py | 38 ++++++- .../utils/generate_param_only_checkpoint.py | 26 ++--- src/maxtext/utils/gradient_accumulation.py | 21 +++- src/maxtext/utils/layerwise_quantization.py | 20 ++-- src/maxtext/utils/lora_utils.py | 13 ++- src/maxtext/utils/muon_utils.py | 5 +- src/maxtext/utils/standalone_checkpointer.py | 15 +-- src/maxtext/utils/vocabulary_tiling.py | 107 ++++++++++++++++++ tests/unit/train_nnx_test.py | 7 -- 28 files changed, 401 insertions(+), 139 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..d4d4c39290 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) + # This conversion script reads paxml-format weights and emits a Linen-format + # MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`, + # `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen + # tree shape). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 6e19ccc445..2b299427cc 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -561,6 +561,13 @@ logical_axis_rules: [ ['tokens_per_page', []], ['paged_kv_head_dim_size', []], # ========================================== + # Pipeline Parallelism + # ========================================== + ['layers_outside_pipeline', []], + ['layers_per_stage', []], + ['num_activations', []], + ['circular_repeats', []], + # ========================================== # Deprecated / Scheduled for Removal # ========================================== ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 406ba92523..c14d87cd4b 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 20594bccc3..fff07186bb 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2902,8 +2902,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py. if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..4244d199a8 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -542,29 +542,28 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. 
""" + # GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below, + # state_mesh_shardings.params, and the inference path through MaxEngine which is + # Linen-only). Run on Linen path regardless of pure_nnx; warn the user since + # NNX-format checkpoints will mismatch at restore time. + if config.pure_nnx or config_inference.pure_nnx: + max_logging.log( + "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " + "running on the Linen path. NNX-format checkpoints will not load correctly here." + ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = mt.from_config(config, devices=training_devices) + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - inference_model = mt.from_config(config_inference, devices=inference_devices) + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -573,14 +572,10 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh - if config_inference.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..c00f475e8d 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. + # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, + # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path + # regardless of pure_nnx — the flag affects training, not inference serving. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 509e1ef7d3..66215fe011 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -525,14 +525,14 @@ def __init__( elif self.is_qwen3_hybrid: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..3942f2fac4 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2250,9 +2250,9 @@ def __call__( w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - # Only apply per expert scales if we have not fused with the out-projections at init time. - if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] + # Only apply per expert scales if we have not fused with the out-projections at init time. + if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] if self.wi_0_sparsity_module is not None: _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..4cadb16701 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -545,8 +545,16 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.checkpoint re-traces the scan body during backward (remat), + # but the Linen scope retains JAX tracers from the first trace, causing + # UnexpectedTracerError. Skip checkpoint for these quantization types. 
+ uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -667,7 +675,23 @@ def layer_fn(carry, scanned_vars): params = nnx_ensure_scan_leading_axis(params, length) state = nnx_ensure_scan_leading_axis(state, length) - final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates + # intermediate tracer values (amax_history float32[1024]) that escape the scan scope, + # causing UnexpectedTracerError. Use a Python for loop instead for these types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 7bb532ae7f..ab61974f7a 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -167,6 +168,39 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Check if the current execution context is inside a Linen init() call. + + Returns True when called from within a ``to_linen_class`` wrapper's + ``init()`` path. Uses :func:`current_linen_module` to access the Linen + module stack (private API already used by this module). + + This is used by NNX pipeline modules to short-circuit the full scan + during Linen init, where only the output shape/dtype is needed. + """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. 
+ """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -476,6 +510,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..35611b2166 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index a75cefc291..6fc0e5d2f6 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..5ba365b74b 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -398,6 +398,19 @@ def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Compute logits from hidden states (wraps NNXDecoder.apply_output_head). 
+ + Mirrors the Linen TransformerLinenPure.logits_from_hidden_states method; + used by vocabulary tiling to recompute logits from chunked hidden states. + """ + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index 09c5b4e079..b743e8d4b7 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -142,6 +143,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -202,7 +204,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..87cb4cc7ef 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -966,7 +966,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -991,7 +991,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..143bf63a07 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -139,7 +139,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -164,7 +164,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. 
self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. 
+ inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) + result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. 
state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index ba9f421648..a2d76f7abd 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -72,7 +72,7 @@ from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -203,9 +203,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs = intermediates.to_pure_dict() if config.num_vocab_tiling > 1: - raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") - - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 @@ -323,7 +324,12 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: if config.use_dpo: - raise NotImplementedError("DPO for NNX modules has not been implemented.") + raise NotImplementedError( + "DPO is not yet supported for NNX modules. DPO requires a reference model " + "stored alongside the policy model (Linen path uses state.params['reference_params']); " + "the NNX TrainState equivalent has not been wired up. As a workaround, set " + "pure_nnx=False for DPO runs." + ) state = nnx.merge(model, state) # reconstruct TrainStateNNX ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] @@ -549,7 +555,11 @@ def move(path, value): if config.use_dpo: new_state = _merge_dpo_state(new_state, reference_params) return new_state, metrics - return nnx.state(new_state), metrics + # Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the + # returned state. Intermediates are transient forward-pass artifacts and must not + # persist across steps: they're absent from the abstract state used to build + # state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX. 
+ return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics def eval_step(model, config, state, data, dropout_rng=None): diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 5f4f1f03da..ea5c7cc087 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -29,6 +29,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -91,6 +92,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. + """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -128,7 +150,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -138,10 +161,17 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -279,7 +309,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. 
+ shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..0f997a6577 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -90,20 +90,17 @@ def slice_ith(input_layers): def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # This script reads a Linen-format full state and emits a Linen-format + # parameter-only checkpoint (downstream `_possibly_unroll_params` and + # `_save_decode_checkpoint` access `state.params["params"]["decoder"]` / `state.opt_state`, + # both Linen-only). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) @@ -114,12 +111,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # LoRA adapters and downstream `_save_decode_checkpoint`/`_possibly_unroll_params` + # are Linen-shaped; use the Linen path regardless of pure_nnx. 
quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e1699647c6..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -71,10 +71,16 @@ def _maybe_shard_with_name(inputs, sharding_names): is_nnx = isinstance(model, nnx.Module) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings @@ -105,7 +111,7 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. - local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state @@ -156,6 +162,11 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). 
+ unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..a6c1c07f67 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -173,19 +173,15 @@ def __init__(self, config: Any, rng: PRNGKeyType): devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config + # Model and quantization config. + # This script produces and consumes Linen-format checkpoints (see DeepSeek*ToLinen + # layer classes used in load_and_quantize). Always use the Linen path internally, + # regardless of the pure_nnx flag — the flag affects training, not checkpoint format. self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 8554d46e3e..1efad6aa91 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" + + from functools import partial import json import os @@ -174,11 +176,10 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # LoRA adapters are Linen-format on disk (downstream `get_lora_abstract_state` expects + # `unboxed_abstract_state.params` Linen tree shape; `lora_state.replace(params=...)` + # uses Linen TrainState API). Use the Linen init path regardless of the pure_nnx flag. 
+ init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3bd2b186b1..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -116,6 +116,7 @@ def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) else: # Linen @@ -154,7 +155,7 @@ def get_leaf_info(leaf): print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -191,6 +192,8 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..2fc2b09e25 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -52,18 +52,15 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + # Standalone checkpointer is a save/restore exerciser that uses + # add_entropy_to_checkpoint() to populate Linen-shaped optimizer state + # (state.opt_state, state.params). Use the Linen path regardless of pure_nnx — + # the flag affects training, not this checkpoint test harness. + model = from_config(config) mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..6a61f9ed23 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -247,3 +247,110 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Calculates cross-entropy loss using vocab tiling for NNX models. + + NNX equivalent of `vocab_tiling_linen_loss`. Iterates the vocab dimension via + `jax.lax.scan` with `model.logits_from_hidden_states` per chunk; the model + carries its parameters internally so no explicit gather is needed. + + This is a memory-efficient forward (chunked logits) but uses the default + autograd path (no custom_vjp), so backward memory savings vs. the Linen + custom_vjp path are not yet realized. TODO: add a custom_vjp using + `nnx.split`/`nnx.merge` if backward memory becomes a concern. + + Args: + model: The NNX model instance (must implement `logits_from_hidden_states`). + hidden_states: The final hidden states from the decoder. + data: A dictionary containing the input data, including 'targets' and 'targets_segmentation'. + config: The model and training configuration. + is_train: A boolean indicating if the model is in training mode. + + Returns: + A tuple (total_loss, total_z_loss). 
+ """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + return total_loss, total_z_loss diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 3495b4c557..f532820f86 100644 --- 
a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -154,13 +154,6 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) - def test_vocab_tiling_raises_not_implemented(self): - cfg, ts = _build_state() - cfg.num_vocab_tiling = 4 - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) - class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" From 10021fe27c8de38fb3fbc142fbb64274dec34d69 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 29 Apr 2026 16:07:35 +0000 Subject: [PATCH 4/4] NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. 
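
Wiring sketch (illustrative only; `create_model`, `tx`, `config`, and `batch` are
placeholders for what train_utils.py and the input pipeline actually build, and the
real dispatch lives in train_step's NNX branch):

    from flax import nnx
    from maxtext.layers.train_state_nnx import TrainStateNNX
    from maxtext.trainers.post_train.dpo.dpo_utils import dpo_loss_fn_nnx

    policy = create_model(config)      # trainable nnx.Module
    reference = create_model(config)   # same init seed, starts identical, never updated
    optimizer = nnx.Optimizer(policy, tx, wrt=nnx.Param)
    state = TrainStateNNX(policy, optimizer, reference_model=reference)

    # dropout_rng / params slots are None for NNX (signature parity with Linen);
    # the reference model rides along as the single extra_dpo_args entry.
    loss, aux = dpo_loss_fn_nnx(
        state.model, config, batch, None, None, state.reference_model, is_train=True
    )
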
Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model
  init/hasattr semantics; apply_gradients leaves reference bit-identical; aux
  key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0
  (strict > on equal logratios); dropout_rng/params slots are signature-compat
  only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on
  policy params only.
- train_nnx_test.py: drop the stale negative test train_step_dpo_raises_for_nnx,
  and setup_train_loop_nnx_test.py: drop test_pure_nnx_dpo_raises_not_implemented;
  NNX DPO is now real (vocab_tiling_raises_not_implemented was already removed in
  an earlier patch of this series).

Stats: 4 source files + 3 test files, +199/-22 source lines. Linen DPO path
behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path
unchanged (all changes gated on config.use_dpo).
---
 src/maxtext/layers/train_state_nnx.py        |  24 +-
 .../trainers/post_train/dpo/dpo_utils.py     | 139 +++++++++++
 src/maxtext/trainers/pre_train/train.py      |  34 +--
 src/maxtext/utils/train_utils.py             |  24 +-
 .../integration/setup_train_loop_nnx_test.py |   9 -
 tests/unit/dpo_nnx_test.py                   | 215 ++++++++++++++++++
 tests/unit/train_nnx_test.py                 |  10 -
 7 files changed, 412 insertions(+), 43 deletions(-)
 create mode 100644 tests/unit/dpo_nnx_test.py

diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py
index 9ef0e6dffd..3f9ee1ce29 100644
--- a/src/maxtext/layers/train_state_nnx.py
+++ b/src/maxtext/layers/train_state_nnx.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-""" The NNX Unified TrainState. """
+"""The NNX Unified TrainState."""
 
 from typing import Any
 
@@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module):
 
   This replaces Linen's TrainState for checkpointing.
   Linen TrainState pytree:
-  {“params”: {...}, “opt_state”: {}...}
+  {"params": {...}, "opt_state": {}...}
   TrainStateNNX state pytree:
-  {“model”: {...}, “optimizer”: {“opt_state”: {...}}
+  {"model": {...}, "optimizer": {"opt_state": {...}}}
+
+  For DPO (Direct Preference Optimization), an optional `reference_model`
+  carries a frozen copy of the same architecture used to compute reference
+  log-probabilities. Only `model` is updated by `apply_gradients`; the
+  reference is held alongside so it is sharded, jit-traced, and checkpointed
+  with the rest of the train state.
   """
 
-  def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None):
+  def __init__(
+      self,
+      model: nnx.Module,
+      optimizer: nnx.Optimizer | None,
+      reference_model: nnx.Module | None = None,
+  ):
     self.model = model
     self.optimizer = optimizer
+    if reference_model is not None:
+      self.reference_model = reference_model
 
   def apply_gradients(self, grads: Any):
     """
     Mimics the Linen apply_gradients function.
     Updates the optimizer state, applies updates to parameters,
-    and increments the step counter.
+    and increments the step counter. Only updates `self.model`;
+    `self.reference_model` (if present) is left untouched.
""" if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. + + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. 
+ """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, 
config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index a2d76f7abd..264fc80d85 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -61,7 +61,7 @@ from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -323,15 +323,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params = state.params ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: - if config.use_dpo: - raise NotImplementedError( - "DPO is not yet supported for NNX modules. DPO requires a reference model " - "stored alongside the policy model (Linen path uses state.params['reference_params']); " - "the NNX TrainState equivalent has not been wired up. As a workaround, set " - "pure_nnx=False for DPO runs." - ) state = nnx.merge(model, state) # reconstruct TrainStateNNX - ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] # --- Gradient computation --- if config.gradient_accumulation_steps > 1: @@ -397,9 +397,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). 
+ _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + def diff_wrapper(param, rest, config, data): local_model = nnx.merge(model_graphdef, param, rest, copy=True) - loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) @@ -579,7 +584,10 @@ def eval_step(model, config, state, data, dropout_rng=None): loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) else: state = nnx.merge(model, state) # reconstruct TrainStateNNX - loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -705,7 +713,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): @@ -749,7 +757,7 @@ def train_loop(config, recorder, state=None): metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. 
def create_train_state_fn(): model = _create_model_partial() optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(model, optimizer) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) init_state_fn = create_train_state_fn else: @@ -316,8 +322,6 @@ def create_train_state_fn(): maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - if config.pure_nnx: - raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -342,9 +346,17 @@ def create_train_state_fn(): except FileNotFoundError: step0_restored = None if step0_restored is not None: - # TODO: For pure_nnx, the dpo state manipulation is different. - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. + step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index d11f9658a7..05a7fcffec 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -126,15 +126,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): del model - def test_pure_nnx_dpo_raises_not_implemented(self): - """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" - # use_dpo requires a few prerequisites; the simplest is to set the flag and - # let setup_train_loop reach the NotImplementedError check before the more - # involved DPO path runs. - config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) - with self.assertRaises(NotImplementedError): - setup_train_loop(config, recorder=None) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""NNX DPO unit tests.
+
+Covers the NNX-native DPO surface:
+  * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model
+    sits alongside policy and is not touched by `apply_gradients`.
+  * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` —
+    aux structure, identical-model invariant (loss = log(2); reward_accuracy = 0.0,
+    since the strict `>` never fires on equal logratios).
+"""
+
+import math
+import types
+import unittest
+
+import jax
+import jax.numpy as jnp
+import optax
+from flax import nnx
+
+from maxtext.layers import train_state_nnx
+from maxtext.trainers.post_train.dpo import dpo_utils
+
+
+class _MockTransformer(nnx.Module):
+  """Tiny NNX transformer-shaped module for DPO tests.
+
+  Accepts the same keyword args that `dpo_loss_fn_nnx` passes:
+  `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`,
+  `enable_dropout`. Other args are tolerated via **kwargs.
+  """
+
+  def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):
+    self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
+    self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs)
+
+  def __call__(
+      self,
+      decoder_input_tokens,
+      decoder_positions=None,
+      decoder_segment_ids=None,
+      enable_dropout=False,
+      **kwargs,
+  ):
+    del decoder_positions, decoder_segment_ids, enable_dropout, kwargs
+    return self.proj(self.embed(decoder_input_tokens))
+
+
+def _make_dpo_config(**overrides):
+  """Build the minimal config surface that `dpo_loss_fn_nnx` reads."""
+  base = {
+      "dpo_label_smoothing": 0.0,
+      "dpo_beta": 0.1,
+      "enable_dropout": False,
+      "num_experts": 1,
+      "micro_batch_size_to_train_on": 2,
+  }
+  base.update(overrides)
+  return types.SimpleNamespace(**base)
+
+
+def _make_dpo_batch(batch_size=2, seq_len=5):
+  """Build a tiny DPO-shaped batch.
+
+  `chosen` and `rejected` share the first 2 tokens (common prefix is masked
+  out in the loss), differ at positions 2 and 3, and are padded at position 4.
+ """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). 
+ """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index f532820f86..4340d4e22a 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -174,16 +174,6 @@ def test_train_step_returns_state_and_metrics(self): self.assertIn("learning/param_norm", metrics["scalar"]) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) - def test_train_step_dpo_raises_for_nnx(self): - cfg, ts = _build_state() - cfg.use_dpo = True - state_graphdef, state_pure = nnx.split(ts) - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.train_step( - state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data - ) - def test_train_step_increments_optimizer_step(self): cfg, ts = _build_state() state_graphdef, state_pure = nnx.split(ts)
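
As a standalone numerical cross-check of the identical-model invariant asserted in
dpo_nnx_test.py's test_identical_policy_and_reference_yields_log2_loss (independent
of the MaxText code; it mirrors the loss arithmetic in dpo_loss_fn_nnx with the
test's beta=0.1 and label_smoothing=0.0):

    import jax
    import jax.numpy as jnp

    # Identical policy and reference: every chosen/rejected logratio is zero.
    chosen_logratios = jnp.zeros((2,))
    rejected_logratios = jnp.zeros((2,))
    label_smoothing, beta = 0.0, 0.1

    logratios_delta = beta * (chosen_logratios - rejected_logratios)
    losses = (
        -jax.nn.log_sigmoid(beta * logratios_delta) * (1 - label_smoothing)
        - jax.nn.log_sigmoid(-beta * logratios_delta) * label_smoothing
    )
    print(float(jnp.mean(losses)))  # 0.6931..., i.e. log(2)
    print(float(jnp.mean(chosen_logratios > rejected_logratios)))  # 0.0 ('>' is strict)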