diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 888cf4d2d1..9b5f0cfb21 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -35,6 +35,7 @@ """ import argparse +import functools import gc import os import sys @@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name mesh = Mesh(devices_array, cfg.mesh_axes) quant = quantizations.configure_quantization(cfg) - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if cfg.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) + if cfg.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 77751479ce..ca02970dd2 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1087,8 +1087,9 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false -pure_nnx_decoder: false +enable_nnx: False +pure_nnx_decoder: False +pure_nnx: False ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 21296e965d..0f30471c79 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -794,6 +794,7 @@ class HardwareAndMesh(BaseModel): optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") + pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") class LayoutAndSharding(BaseModel): diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 100434ef74..28eef21cb0 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -546,23 +546,43 @@ def setup_train_loop( max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - 
model = mt.from_config(config, devices=training_devices) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - inference_model = mt.from_config(config_inference, devices=inference_devices) + if config_inference.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) # create inference_state_mesh_shardings from inference_mesh + if config_inference.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( - inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False + config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None): data_buffer = [] data_buffer_lock = threading.Lock() - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) inference_prof = profiler.Profiler(config_inference, offset_step=start_step) data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 02a2f392c2..23cd2387db 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -229,17 +232,25 @@ def load_params(self, 
*args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( - self.model, None, self.config, rng, self._mesh, False + self.config, self._mesh, init_state_fn, False ) # reshard given params based on shardings from config in MaxEngine params = jax.device_put(params, state_mesh_shardings.params) state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - state, self.state_mesh_annotations = maxtext_utils.setup_decode_state( - self.model, self.config, rng1, self._mesh, None - ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py new file mode 100644 index 0000000000..9ef0e6dffd --- /dev/null +++ b/src/maxtext/layers/train_state_nnx.py @@ -0,0 +1,48 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" The NNX Unified TrainState. """ + +from typing import Any + +from flax import nnx + + +class TrainStateNNX(nnx.Module): + """ + A unified container for NNX models and optimizers. + This replaces Linen's TrainState for checkpointing. + + Linen TrainState pytree: + {"params": {...}, "opt_state": {...}} + TrainStateNNX state pytree: + {"model": {...}, "optimizer": {"opt_state": {...}}} + """ + + def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + self.model = model + self.optimizer = optimizer + + def apply_gradients(self, grads: Any): + """ + Mimics the Linen apply_gradients function. + Updates the optimizer state, applies updates to parameters, + and increments the step counter. + """ + if self.optimizer is None: + raise RuntimeError( + "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. " + "This usually happens when the state was created for inference only. 
+ ) + self.optimizer.update(self.model, grads) diff --git a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py index 7cc8f5b658..c7f6bd4740 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py @@ -85,7 +85,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 4b3505b224..b0e4d8b690 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -75,8 +75,10 @@ VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() -def get_first_step(state): - return int(state.step) +def get_first_step(model, state): + if isinstance(model, nn.Module): + return int(state.step) + return int(state.optimizer.step.get_value()) # ----------------------------------------------------------------------------- @@ -512,7 +514,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git 
a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 408340016e..15af61a572 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -27,6 +27,7 @@ from typing import Sequence from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning import jax from jax.experimental.serialize_executable import serialize @@ -36,6 +37,7 @@ from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations +from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.trainers.diloco import diloco @@ -44,6 +46,8 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils # pylint: disable=too-many-positional-arguments @@ -93,7 +97,10 @@ def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state quant = quantizations.configure_quantization(config) - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if config.pure_nnx: + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) + else: + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. 
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -103,18 +110,39 @@ def get_shaped_inputs(topology_mesh, config): _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) - # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, example_rng, topology_mesh - ) + if config.pure_nnx: + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) - # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) + # Shaped state + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) + + if config.pure_nnx: + # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. + logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + # For NNX, get_functional_train_with_signature expects the graphdef (static structure), + # not the raw model — mirroring how the training loop does nnx.split(train_state). 
+ with nn_partitioning.axis_rules(config.logical_axis_rules): + graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + model = graphdef + else: + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) + if config.pure_nnx: + shaped_train_args = (abstract_state, shaped_batch) + else: + shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -277,12 +305,20 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(topology_mesh) - maxtext_utils.print_shardings_params( - shaped_train_args[0].params, - state_mesh_shardings.params, - topology_mesh, - logical_annotations.params, - ) + if config.pure_nnx: + maxtext_utils.print_shardings_params( + shaped_train_args[0], + state_mesh_shardings, + topology_mesh, + logical_annotations, + ) + else: + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + logical_annotations.params, + ) # Compile print("Jitting and compiling train step...", flush=True) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 7c520cc470..2fd14b87a2 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -22,6 +22,7 @@ The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. 
""" +import functools import os.path from typing import Sequence @@ -42,8 +43,6 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -Transformer = models.transformer_as_linen - def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" @@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager + None, config, mesh, checkpoint_manager, init_state_fn ) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") @@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 4be05ff7e1..36e612a3f9 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -30,6 +30,7 @@ """ +import functools import os from typing import Any, Sequence @@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType): # Model and quantization config self.quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, None, self.config, self.rng, self._mesh, False - ) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX 
support has not been implemented yet.") + else: + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + + self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) def load_and_quantize(self) -> None: """ diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 03095edd73..24099ef22a 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,6 +14,7 @@ """ Common LoRA utils needed to support LoRA adapters.""" +from functools import partial import json import jax @@ -166,7 +167,12 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh, True) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index dab8103a4f..5700cbf8dc 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -18,6 +18,7 @@ import functools import pickle import os +from typing import Sequence from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -27,6 +28,7 @@ from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load +from jax.sharding import AxisType, Mesh import jax import jax.numpy as jnp @@ -36,7 +38,8 @@ import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.configs import pyconfig +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing @@ -196,8 +199,11 @@ def get_train_input_output_trees(func, input_args, input_kwargs): serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file) shaped_batch = get_shaped_batch(config) - example_rng = jax.random.PRNGKey(0) - shaped_input_args = (state, shaped_batch, example_rng) + if config.pure_nnx: + shaped_input_args = (state, shaped_batch) + else: + example_rng = jax.random.PRNGKey(0) + shaped_input_args = (state, shaped_batch, 
example_rng) shaped_input_kwargs = {} in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices) @@ -1050,14 +1056,13 @@ def get_abstract_param(model, config): return abstract_vars -def setup_decode_state(model, config, rng, mesh, checkpoint_manager): +def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): """Setup decode state by loading params from a checkpoint. Args: - model: the flax model to initialize config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: Checkpoint manager + init_state_fn: function to initialize the model state Returns: state: state with decode params loaded from the checkpoint @@ -1067,12 +1072,12 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): # generate random params max_logging.log("No decode checkpoint specified - generating random weights.") state, state_mesh_annotations, _, _ = setup_initial_state( - model, None, None, config, rng, mesh, checkpoint_manager, False + None, config, mesh, checkpoint_manager, init_state_fn, False ) else: # Load params from checkpoint max_logging.log(f"Loading decode params from {config.load_parameters_path}") - unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False) + unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(config, mesh, init_state_fn, False) with nn_partitioning.axis_rules(config.logical_axis_rules): params = checkpointing.load_params_from_path( config.load_parameters_path, @@ -1087,40 +1092,35 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): return state, state_mesh_annotations -def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): +def setup_training_state(data_iterator, config, mesh, checkpoint_manager, init_state_fn): is_training 
= True return setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training, ) def setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training=True, ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. Args: - model: the flax model to initialize - tx: the optax.GradientTransformation + data_iterator: data iterator config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: an Orbax checkpointing.CheckpointManager object + init_state_fn: function to initialize the training state is_training: True to initialize training state, False for decode state Returns: @@ -1129,7 +1129,7 @@ def setup_initial_state( """ unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( - model, tx, config, rng, mesh, is_training + config, mesh, init_state_fn, is_training ) # Initialization @@ -1164,14 +1164,14 @@ def setup_initial_state( # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] else: - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" # pylint: disable=not-callable state = jax.jit( init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings, - )(rng) + )() if raw_params: # If we loaded a partial state, we need to merge it. 
 state = state.replace(params=raw_params) @@ -1180,8 +1180,8 @@ return state, state_mesh_annotations, state_mesh_shardings, data_iterator -def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) +def get_logical_annotations(config, mesh, init_state_fn): + init_state_partial = init_state_fn with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) @@ -1189,9 +1189,9 @@ return logical_annotations -def get_abstract_state(model, tx, config, rng, mesh, is_training=True): +def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) @@ -1524,3 +1524,27 @@ delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging all_host_upload=False, # Only upload from lead host (Host 0) ) + + +def get_mesh_from_config( + config: pyconfig.HyperParameters, + devices: Sequence[jax.Device] | None = None, +) -> Mesh: + """ + Get mesh from the configuration. 
+ + Args: + config: the configuration + devices: the devices + + Returns: + the device mesh + """ + devices_array = create_device_mesh(config, devices) + + if config.shard_mode == ShardMode.EXPLICIT: + axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) + else: + axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) + + return Mesh(devices_array, config.mesh_axes, axis_types=axis_types) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py new file mode 100644 index 0000000000..7378928ef2 --- /dev/null +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -0,0 +1,172 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Utils for MaxText NNX. 
""" + +from functools import partial +from typing import Callable + +from flax import nnx +import jax +from jax.sharding import Mesh, NamedSharding + +from maxtext.utils import max_logging +from maxtext.configs import pyconfig + + +def create_nnx_rngs( + config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None +) -> nnx.Rngs: + """ + Create NNX Rngs + + Args: + config: the configuration + is_training: if the Rngs are for training + rng_key: the Rng key + + Returns: + The NNX Rngs + """ + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + + if is_training: + return nnx.Rngs( + params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) + ) + return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference + + +def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: + """Get named sharding from NNX abstract state. + + Args: + abstract_state: NNX model abstract state created from nnx.get_abstract_model. + + Returns: + named sharding structure + """ + # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we + # get the existing sharding from the abstract_state. + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + return jax.tree.map( + lambda x: x.sharding, + abstract_state, + is_leaf=lambda x: isinstance(x, jax.ShapeDtypeStruct), + ) + + +def get_partition_spec_nnx(named_sharding: nnx.State) -> nnx.State: + """Get mesh partition spec from named sharding. + + Args: + named_sharding: NNX model named sharding. + + Returns: + mesh partition spec + """ + # The leaf is of type NamedSharding. + return jax.tree.map( + lambda x: x.spec, + named_sharding, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + + +def set_named_sharding_nnx(abstract_state: nnx.State, named_sharding: nnx.State) -> nnx.State: + """Set named sharding to NNX abstract state. 
+ + Args: + abstract_state: NNX model abstract state created from nnx.get_abstract_model(). + named_sharding: named sharding. It must have the same tree structure with abstract_state. + + Returns: + updated abstract_state + """ + return jax.tree.map(lambda x, y: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=y), abstract_state, named_sharding) + + +def move_memory_to_host(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: + """ + Change the memory_kind of the NamedSharding to "pinned_host". This function can be + called by jax.tree_util.tree_map_with_path on a NNX state structure. + + Args: + path: the tree path tuple + x: the NamedSharding corresponding to the path + + Returns: + the NamedSharding with memory_kind set to "pinned_host" + """ + max_logging.log(f"max_utils.py: Moving {path} to host") + # Create the new sharding with the target memory kind + return x.with_memory_kind(kind="pinned_host") + + +def move_memory_to_device(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: + """ + Change the memory_kind of the NamedSharding to "device". This function can be + called by jax.tree_util.tree_map_with_path on a NNX state structure. + + Args: + path: the tree path tuple + x: the NamedSharding corresponding to the path + + Returns: + the NamedSharding with memory_kind set to "device" + """ + max_logging.log(f"max_utils.py: Moving {path} to device") + # Create the new sharding with the target memory kind + return x.with_memory_kind(kind="device") + + +def create_nnx_sharded_model( + abstract_model: nnx.Module, + init_fn: Callable, + mesh: Mesh | None = None, + named_sharding: nnx.State | None = None, +) -> nnx.Module: + """ + Create the model with the given sharding. 
+ + Args: + abstract_model: the abstract model + init_fn: the model init function + mesh: the device mesh + named_sharding: the given sharding + + Returns: + The initialized sharded model + """ + graphdef, abstract_state = nnx.split(abstract_model) + if named_sharding is None: + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + # we get the sharding directly from it. + named_sharding = get_named_sharding_nnx(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + @partial(jax.jit, out_shardings=named_sharding) + def create_sharded_state(): + model = init_fn() + return jax.lax.with_sharding_constraint(nnx.state(model), named_sharding) + + # Create the model with sharded parameters. + with jax.set_mesh(mesh): + sharded_state = create_sharded_state() + return nnx.merge(graphdef, sharded_state) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index b3057d0518..d96e0d6543 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -18,18 +18,16 @@ from collections.abc import Sequence from functools import partial from typing import overload - from etils import epath from flax import nnx import flax.linen as nn import jax -from jax.sharding import AxisType, Mesh +from jax.sharding import Mesh from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import max_utils -from maxtext.utils import maxtext_utils +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx from orbax import checkpoint as ocp 
@@ -40,6 +38,7 @@ def from_config( mesh: Mesh | None = None, *, model_mode: str = MODEL_MODE_TRAIN, + rngs: None = None, ) -> nn.Module: ... @@ -80,15 +79,7 @@ def from_config( model = from_config(config) """ if mesh is None: - devices_array = maxtext_utils.create_device_mesh(config, devices) - - if config.shard_mode == ShardMode.EXPLICIT: - axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) - else: - axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) - - mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types) - + mesh = maxtext_utils.get_mesh_from_config(config, devices) model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) # Return only the model @@ -114,16 +105,10 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + is_training = model_mode == MODEL_MODE_TRAIN def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - - if model_mode == MODEL_MODE_TRAIN: - rngs = nnx.Rngs(params=rng_key, dropout=1) - else: - rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference - + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) @@ -136,6 +121,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, if mesh is None: mesh = abstract_model.mesh + # Note for pure_nnx: + # Currently, the NNX model returned has a linen decoder wrapped to NNX. 
So it is not a pure NNX model and + # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen + # LogicallyPartitioned structure. + # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned + # structure in the abstract state and we can get the sharded state with the following code: + # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) + # abstract_model = nnx.merge(graphdef, state) + # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) + # sharded_state = nnx.state(model) + # JIT a function that creates the model state with proper sharding from the start. # By providing out_shardings, we instruct JAX to produce sharded output directly, # avoiding a large intermediate allocation on a single device. diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 00eb408ad3..1d0e37e29b 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -16,6 +16,8 @@ """ Utils that are only interesting for training in MaxText. 
""" import os +from functools import partial + import jax import functools from flax.linen import partitioning as nn_partitioning @@ -33,12 +35,17 @@ from maxtext.trainers.diloco import diloco -def create_training_tools(config, model, mesh): - """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" - init_rng = jax.random.PRNGKey(config.init_weights_seed) +def create_training_optimizer(config, model): + """Creates the optimizer and learning rate schedule.""" learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon tx = optimizers.get_optimizer(config, learning_rate_schedule, model) + return learning_rate_schedule, tx + + +def create_checkpoint_manager(config, mesh, init_state_fn): + """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" + # pass in model for muon logger = checkpointing.setup_checkpoint_logger(config) if config.enable_multi_tier_checkpointing: checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( @@ -47,7 +54,7 @@ def create_training_tools(config, model, mesh): mesh, ) elif config.enable_emergency_checkpoint: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( config.local_checkpoint_directory, config.checkpoint_dir, @@ -84,10 +91,10 @@ def create_training_tools(config, model, mesh): config.enable_single_replica_ckpt_restoring, ) - return init_rng, checkpoint_manager, learning_rate_schedule, tx + return checkpoint_manager -def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): +def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=None): """Returns a 
JIT-compiled train step function, which is loaded from a file if specified in the config.""" if config.enable_diloco: functional_train = train_step @@ -109,7 +116,9 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit if config.compiled_trainstep_file != "": max_logging.log("Loading the compiled function...") - execution_devices = model.mesh.devices.flatten().tolist() + # For NNX, model is the GraphDef (no .mesh); use the mesh passed explicitly instead. + execution_mesh = mesh if mesh is not None else model.mesh + execution_devices = execution_mesh.devices.flatten().tolist() # Need to pass train signature and state to determine i/o shapes of train_state for now. p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices) max_logging.log("Loaded compiled function!") @@ -164,7 +173,9 @@ def jit_train_and_eval_step( train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) train_step = diloco.build_diloco_train_step(config, train_step_partial) data_sharding = sharding.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) + p_train_step = jit_train_step( + config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh + ) p_eval_step = None if eval_data_iterator: p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) @@ -196,9 +207,21 @@ def setup_train_loop(config, recorder, devices=None): from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): - model = model_creation_utils.from_config(config, devices) + is_training = True + init_rng = jax.random.PRNGKey(config.init_weights_seed) + if 
config.pure_nnx: + # Create abstract NNX model. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = model_creation_utils.from_config(config, devices) mesh = model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) + learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) + checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) @@ -224,7 +247,7 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) if config.enable_diloco: @@ -247,14 +270,14 @@ def setup_train_loop(config, recorder, devices=None): # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) max_utils.print_non_trivial_mesh_axis(model.mesh) maxtext_utils.print_shardings_params( state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params ) if config.use_dpo: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring 
reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index e4e9f4fe8a..cae8b9e4d3 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ b/tests/assets/logits_generation/generate_grpo_golden_logits.py @@ -38,7 +38,7 @@ from maxtext.inference.maxengine import maxengine from maxtext.models import models from maxtext.utils import maxtext_utils -from tests.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs +from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs import numpy as np import torch import transformers @@ -73,17 +73,27 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) # With checkpoint - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules) self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # Without checkpoint - self.model_no_ckpt_loading = models.transformer_as_linen( - config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN - ) - self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state( - self.model_no_ckpt_loading, self.cfg_no_ckpt_loading, self.rng, mesh, None - ) + if self.cfg_no_ckpt_loading.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model_no_ckpt_loading = models.transformer_as_linen( + config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial( + maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng + ) + self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", diff --git a/tests/post_training/integration/grpo_correctness.py b/tests/post_training/integration/grpo_correctness.py index 44a3e28df7..adefc03a7e 100644 --- a/tests/post_training/integration/grpo_correctness.py +++ b/tests/post_training/integration/grpo_correctness.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""GRPO correctness tests""" +import functools import os import unittest @@ -60,8 +61,13 @@ def setUp(self): self.rng = jax.random.PRNGKey(42) devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, _ = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", add_bos_token=False, @@ -121,7 +127,7 @@ def _prepare_maxtext_inputs(self): ) def _prepare_trl_inputs(self): - """Prepare TRL inputs.""" + """Prepare inputs for TRL model.""" tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt") input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1) attention_mask = torch.cat( diff --git a/tests/post_training/integration/grpo_trainer_correctness_test.py b/tests/post_training/integration/grpo_trainer_correctness_test.py index 5994ab4b31..24a3ceaf1d 100644 --- a/tests/post_training/integration/grpo_trainer_correctness_test.py +++ b/tests/post_training/integration/grpo_trainer_correctness_test.py @@ -25,6 +25,7 @@ pytest tests/post_training/integration/grpo_trainer_correctness_test.py """ +import functools import os import subprocess import sys @@ -72,8 +73,13 @@ def setup_maxtext_model(config, mesh): init_rng = 
jax.random.PRNGKey(config.init_weights_seed) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, state_mesh_annotations = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, state_mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, config.logical_axis_rules) data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) reference_params = jax.tree.map(jnp.copy, state.params["params"]) diff --git a/tests/post_training/integration/sft_trainer_correctness_test.py b/tests/post_training/integration/sft_trainer_correctness_test.py index 9ed48a0492..aeb7c77bfc 100644 --- a/tests/post_training/integration/sft_trainer_correctness_test.py +++ b/tests/post_training/integration/sft_trainer_correctness_test.py @@ -24,6 +24,7 @@ pytest tests/post_training/integration/sft_trainer_correctness_test.py """ +import functools import os.path import subprocess import sys @@ -117,8 +118,13 @@ def setup_maxtext_model(config): quant = quantizations.configure_quantization(config) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a 
different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) return maxtext_model, state, init_rng diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py new file mode 100644 index 0000000000..0eb1f7ef77 --- /dev/null +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -0,0 +1,182 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Tests for the common MaxText NNX utilities """ +import unittest +from dataclasses import dataclass +from typing import Any +import jax +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental import mesh_utils + +from maxtext.utils import maxtext_utils_nnx + + +class TestMaxTextUtilsNNX(unittest.TestCase): + """Test the functions for MaxText Utils.""" + + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" + + init_weights_seed: int = 42 + + class TinyModel(nnx.Module): + """ + A tiny NNX model with logical annotations. + Annotations are required to test that sharding extraction logic works. 
+ """ + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + jax.device_count(), + jax.device_count(), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)), + # FIX: Removed () from zeros. zeros is the initializer function itself, + # not a factory like lecun_normal(). + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)), + rngs=rngs, + ) + + def tiny_model_init_fn(self): + """Factory function for model initialization.""" + return self.TinyModel(rngs=nnx.Rngs(0)) + + def setUp(self): + # Create a mesh for sharding tests. + # NamedSharding requires an active Mesh to resolve logical names. + self.devices = mesh_utils.create_device_mesh((jax.device_count(),)) + self.mesh = Mesh(self.devices, axis_names=("data",)) + + def test_create_nnx_rngs_training(self): + # Using Any to satisfy static type checkers for the MockConfig + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=True) + + self.assertIsInstance(rngs, nnx.Rngs) + # FIX: nnx.Rngs does not have a .streams attribute. + # Check for stream attributes directly on the object. 
+ self.assertTrue(hasattr(rngs, "params")) + self.assertTrue(hasattr(rngs, "dropout")) + self.assertTrue(hasattr(rngs, "aqt")) + + def test_create_nnx_rngs_inference(self): + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=False) + + self.assertIsInstance(rngs, nnx.Rngs) + # Check that 'params' exists but 'dropout' and 'aqt' were excluded + self.assertTrue(hasattr(rngs, "params")) + self.assertFalse(hasattr(rngs, "dropout")) + self.assertFalse(hasattr(rngs, "aqt")) + + def test_move_memory(self): + sharding = NamedSharding(self.mesh, P("data")) + self.assertNotEqual(sharding.memory_kind, "pinned_host") + + path = ("layers", "linear", "kernel") + host_sharding = maxtext_utils_nnx.move_memory_to_host(path, sharding) + + self.assertEqual(host_sharding.memory_kind, "pinned_host") + self.assertEqual(host_sharding.spec, P("data")) + + device_sharding = maxtext_utils_nnx.move_memory_to_device(path, sharding) + + self.assertEqual(device_sharding.memory_kind, "device") + self.assertEqual(device_sharding.spec, P("data")) + + def test_get_set_named_sharding_nnx(self): + # 1. Create the abstract state using standard NNX functional API + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + + # 2. 
Test extraction + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # Verify kernel and bias match the P("data") annotations from TinyModel + self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) + self.assertEqual(extracted_shardings.linear.bias.get_value().spec, P("data")) + + # Target kernel spec update + new_kernel_spec = P(None, "data") + + def update_spec_fn(path, leaf_sharding): + path_str = jax.tree_util.keystr(path) + if "linear" in path_str and "kernel" in path_str: + # Construct a new NamedSharding with the requested logical spec + return NamedSharding(leaf_sharding.mesh, new_kernel_spec) + return leaf_sharding + + # Apply the spec change to the extracted sharding tree + extracted_shardings = jax.tree.map_with_path(update_spec_fn, extracted_shardings) + + # 3. Test setting new shardings + # Transform the extracted shardings to host memory + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, new_shardings) + + # Verify the metadata inside the abstract state leaf has updated its sharding + self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host") + # Also verify the spec was updated successfully + self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec) + + # 4. Verify named sharding is preserved after NNX merge (update) and split (state) + model = self.tiny_model_init_fn() + nnx.update(model, updated_abstract) + re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) + + # Verify kernel and bias have expected sharding + self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) + self.assertEqual(re_extracted_shardings.linear.bias.get_value().spec, P("data")) + + def test_create_nnx_sharded_model(self): + # 1. 
Create abstract model + graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + abstract_model = nnx.merge(graphdef, abstract_state) + + # 2. Modify shardings to trigger host offloading + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + + # 3. Run the sharded creation + # We pass the abstract model and use the custom sharding for instantiation + sharded_model = maxtext_utils_nnx.create_nnx_sharded_model( + abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=new_shardings + ) + + # 4. Verify the model is concrete (contains Arrays) and sharded on host + self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array) + self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host") + + def test_get_partition_spec_nnx(self): + """Verifies extraction of PartitionSpecs from NamedShardings.""" + # 1. Create abstract state and get sharding + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # 2. Execute extraction + spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) + + # 3. 
Verify that the leaves are now raw PartitionSpecs + # Expected values derived from TinyModel definition + expected_spec_k = P("data", None) + expected_spec_b = P("data") + + self.assertEqual(spec["linear"]["kernel"], expected_spec_k) + self.assertEqual(spec["linear"]["bias"], expected_spec_b) + self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index a65a905c7f..4b03da2a80 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -14,10 +14,12 @@ """Tests for the common MaxText utilities""" +import functools +from typing import Any, Sequence from collections.abc import Callable -from typing import Any import unittest -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch +from dataclasses import dataclass, field from flax import linen as nn from flax import nnx @@ -26,9 +28,9 @@ import jax from jax import random, vmap import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils from maxtext.layers import quantizations from maxtext.models import models @@ -351,18 +353,31 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.config) self.mesh = Mesh(devices_array, self.config.mesh_axes) quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(self.config, mesh=self.mesh, quant=quant, 
model_mode=MODEL_MODE_TRAIN) def test_setup_decode_state(self): rng = random.PRNGKey(0) - state, _ = maxtext_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + state, _ = maxtext_utils.setup_decode_state(self.config, self.mesh, None, init_state_fn) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - state, _, _, _ = maxtext_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) + state, _, _, _ = maxtext_utils.setup_initial_state(None, self.config, self.mesh, None, init_state_fn) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) @@ -908,38 +923,65 @@ def test_wsd_schedule(self): self.assertIn("wsd_decay_steps_fraction", str(cm.exception)) -class TestGetAbstractState(unittest.TestCase): - """Test class for get_abstract_state.""" +class TestMeshUtils(unittest.TestCase): + """Test suite for the mesh creation utility function.""" - def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize( - [None, get_test_config_path()], - **extra_args, - enable_checkpointing=False, - model_name="llama3.1-8b", - per_device_batch_size=1, - max_target_length=16, - ) - devices_array = maxtext_utils.create_device_mesh(self.config) - self.mesh = Mesh(devices_array, self.config.mesh_axes) - 
quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - self.rng = jax.random.PRNGKey(0) - self.tx = optax.adam(learning_rate=0.001) - - def test_get_abstract_state(self): - """Tests that get_abstract_state returns abstract arrays.""" - # get_abstract_state returns a tuple, the first element is the abstract state. - abstract_state, _, _ = maxtext_utils.get_abstract_state(self.model, self.tx, self.config, self.rng, self.mesh, None) + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" - # Check that params are abstract - param_leaves = jax.tree_util.tree_leaves(abstract_state.params) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in param_leaves)) + init_weights_seed: int = 42 + shard_mode: str = ShardMode.EXPLICIT + mesh_axes: Sequence[str] = field(default_factory=lambda: ["data", "model"]) - # Check that opt_state is abstract - opt_state_leaves = jax.tree_util.tree_leaves(abstract_state.opt_state) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) + def setUp(self): + # Setup a dummy device array for the mock to return + self.devices_array = np.array(jax.devices()) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_explicit_mode(self, mock_create_device_mesh): + """Tests that ShardMode.EXPLICIT sets axis_types to Explicit.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:1].reshape((1,)) + config = self.MockConfig(shard_mode=ShardMode.EXPLICIT, mesh_axes=["data"]) + + # 2. Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. 
Assertions + # Check that the internal utility was called correctly + mock_create_device_mesh.assert_called_once_with(config, None) + + # Verify Mesh properties + self.assertEqual(mesh.axis_names, ("data",)) + # ShardMode.EXPLICIT produces AxisType.Explicit axes on the mesh + self.assertEqual(mesh.axis_types, (AxisType.Explicit,)) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_auto_mode(self, mock_create_device_mesh): + """Tests that ShardMode.AUTO sets axis_types to AUTO.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:2].reshape((2, 1)) + config = self.MockConfig(shard_mode=ShardMode.AUTO, mesh_axes=["data", "model"]) + + # 2. Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. Assertions + self.assertEqual(len(mesh.axis_types), 2) + self.assertTrue(all(t == AxisType.Auto for t in mesh.axis_types)) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_with_provided_devices(self, mock_create_device_mesh): + """Tests that provided devices are passed through to the mesh creator.""" + config = self.MockConfig() + specific_devices = self.devices_array[:2].reshape((1, 2)) + mock_create_device_mesh.return_value = specific_devices + + _ = maxtext_utils.get_mesh_from_config(config, devices=specific_devices) + + # Verify the second argument to create_device_mesh was our device list + mock_create_device_mesh.assert_called_once_with(config, specific_devices) if __name__ == "__main__": diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index 2cd696f241..c9e4deb725 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -14,6 +14,7 @@ """Compare expected sharding of models with actual sharding of models.""" +import functools import hashlib import json import os @@ -127,6 +128,9 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) f"model_name={model_name}", 
"log_config=false", "debug_sharding=true", # for input sharding dump + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] root_dir = "tests/utils/sharding_info" @@ -215,6 +219,9 @@ def abstract_state_and_shardings(request): f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", "weight_dtype=float32", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] config = pyconfig.initialize(params) validate_config(config) @@ -228,13 +235,15 @@ def abstract_state_and_shardings(request): tx = optimizers.get_optimizer(config, learning_rate_schedule) rng = jax.random.PRNGKey(0) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # Get abstract state and physical shardings from maxtext_utils abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, rng, topology_mesh, is_training=True + config, topology_mesh, init_state_fn, is_training=True ) # Get logical shardings from maxtext_utils - logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True) + logical_shardings = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings diff --git a/tests/unit/state_dtypes_test.py b/tests/unit/state_dtypes_test.py index 77e166193a..10db1bf199 100644 --- a/tests/unit/state_dtypes_test.py +++ b/tests/unit/state_dtypes_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
""" Test that all weights are expected dtype (default float32) """ +from functools import partial import unittest import jax @@ -47,7 +48,12 @@ def get_state(self, argv): tx = optimizers.get_optimizer(config, learning_rate_schedule) _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, example_rng, mesh) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) return abstract_state def get_weights(self, argv): diff --git a/tests/unit/train_utils_test.py b/tests/unit/train_utils_test.py new file mode 100644 index 0000000000..a8b9458794 --- /dev/null +++ b/tests/unit/train_utils_test.py @@ -0,0 +1,196 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for train_utils.py.""" + +import unittest +from dataclasses import dataclass +from unittest.mock import MagicMock + +from maxtext.utils.train_utils import validate_train_config, create_training_optimizer + + +@dataclass +class MockConfig: + """Minimal mock config for validate_train_config tests.""" + + run_name: str = "test_run" + dataset_path: str = "gs://test-bucket/data" + base_output_directory: str = "gs://test-bucket/output" + steps: int = 100 + quantization: str = "" + gradient_accumulation_steps: int = 1 + packing: bool = False + dataset_type: str = "tfds" + + # Fields needed for create_training_optimizer + opt_type: str = "adamw" + adam_b1: float = 0.9 + adam_b2: float = 0.95 + adam_eps: float = 1e-8 + adam_eps_root: float = 0.0 + adam_weight_decay: float = 0.1 + mu_dtype: str = "" + learning_rate: float = 1e-4 + learning_rate_schedule_steps: int = 1000 + warmup_steps_fraction: float = 0.1 + cosine_learning_rate_final_fraction: float = 0.0 + steps: int = 100 + lr_schedule_type: str = "cosine" + use_iota_embed: bool = False + + +class TestValidateTrainConfig(unittest.TestCase): + """Tests for validate_train_config.""" + + def test_valid_config_passes(self): + """Verifies no exception raised for a valid config.""" + config = MockConfig() + # Should not raise + validate_train_config(config) + + def test_missing_run_name_raises(self): + """Verifies AssertionError when run_name is empty.""" + config = MockConfig(run_name="") + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_zero_steps_raises(self): + """Verifies AssertionError when steps is 0.""" + config = MockConfig(steps=0) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_negative_steps_raises(self): + """Verifies AssertionError when steps is negative.""" + config = MockConfig(steps=-5) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_grad_accumulation_raises(self): + 
"""Verifies AssertionError for fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=2) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_nanoo_fp8_with_grad_accumulation_raises(self): + """Verifies AssertionError for nanoo_fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="nanoo_fp8", gradient_accumulation_steps=4) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_single_grad_accumulation_passes(self): + """Verifies no error for fp8 with gradient_accumulation_steps=1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=1) + validate_train_config(config) # Should not raise + + def test_packing_with_synthetic_data_logs_warning(self): + """Verifies no exception for packing + synthetic (just logs a warning).""" + config = MockConfig(packing=True, dataset_type="synthetic") + # Should not raise - just log a warning + validate_train_config(config) + + def test_local_dataset_path_logs_warning(self): + """Verifies no exception for local dataset_path (just logs a warning).""" + config = MockConfig(dataset_path="/local/path/to/data") + validate_train_config(config) # Should not raise + + def test_local_output_directory_logs_warning(self): + """Verifies no exception for local base_output_directory (just logs a warning).""" + config = MockConfig(base_output_directory="/local/output") + validate_train_config(config) # Should not raise + + +class TestCreateTrainingOptimizer(unittest.TestCase): + """Tests for create_training_optimizer.""" + + def _make_config(self, opt_type="adamw", **kwargs): + """Creates a mock config for optimizer tests.""" + cfg = MockConfig(opt_type=opt_type, **kwargs) + return cfg + + def _mock_lr_schedule(self): + """Returns a mock learning rate schedule that returns a fixed value.""" + return lambda step: 1e-4 + + def 
test_adamw_optimizer_returns_schedule_and_tx(self): + """Verifies create_training_optimizer returns a schedule and optax transform for adamw.""" + config = MagicMock() + config.opt_type = "adamw" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + schedule, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(schedule) + self.assertIsNotNone(tx) + # Verify it's an optax GradientTransformation + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_adam_pax_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for adam_pax optimizer.""" + config = MagicMock() + config.opt_type = "adam_pax" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_sgd_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for sgd optimizer.""" + config = MagicMock() + config.opt_type = "sgd" + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.0 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = 
"cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index c176e53883..c4694d9460 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -37,6 +37,7 @@ """Check if the logits generated by a model's src/MaxText/HF implementation matches golden logits for the same inputs""" import argparse +import functools import os from pathlib import Path import sys @@ -242,8 +243,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng1) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) if test_args.golden_logits_path == "": input_golden_data_path = os.path.join( @@ -424,8 +430,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1) + maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) prompts = ["I love to", "Today is a", "What is the"] all_data_to_save = [] diff --git a/tools/gcs_benchmarks/standalone_checkpointer.py b/tools/gcs_benchmarks/standalone_checkpointer.py index 6240c10cc0..9f39cc529f 100644 --- a/tools/gcs_benchmarks/standalone_checkpointer.py +++ b/tools/gcs_benchmarks/standalone_checkpointer.py @@ -19,6 +19,7 @@ # See github.com/google/maxtext/issues/20 for more import datetime +from functools import partial import os from typing import Sequence @@ -51,15 +52,21 @@ def checkpoint_loop(config, state=None): Returns: """ - model = from_config(config) + if config.pure_nnx: + raise 
NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = from_config(config) mesh = model.mesh - init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools( - config, model, mesh - ) - - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, tx, config, init_rng, mesh, is_training=True - ) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + _, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) + + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() @@ -82,30 +89,24 @@ def checkpoint_loop(config, state=None): if state is not None: # Checkpoint was available for restore if jax.process_index() == 0: max_logging.log( - "STANDALONE CHECKPOINTER : Checkpoint restored in :" - f" {checkpoint_load_end - checkpoint_load_start}" + "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}" ) else: # Checkpoint was unavailable, state needs to be initialized - state, _, _, _ = maxtext_utils.setup_training_state( - model, None, tx, config, init_rng, mesh, checkpoint_manager - ) + state, _, _, _ = maxtext_utils.setup_training_state(None, config, mesh, checkpoint_manager, init_state_fn) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step 
for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() # A barrier to sync all hosts before starting to save checkpoint - jax.experimental.multihost_utils.sync_global_devices( - "Barrier before save" - ) + jax.experimental.multihost_utils.sync_global_devices("Barrier before save") if checkpointing.save_checkpoint(checkpoint_manager, int(step), state): checkpoint_manager.wait_until_finished() end_time = datetime.datetime.now() if jax.process_index() == 0: max_logging.log( - "STANDALONE CHECKPOINTER : Checkpoint saved in" - f" {end_time - start_time} ,step {step}, on host 0" + "STANDALONE CHECKPOINTER : Checkpoint saved in" f" {end_time - start_time} ,step {step}, on host 0" ) return state @@ -123,12 +124,8 @@ def add_entropy_to_checkpoint(state): state: Returns state with entropy added to the optimizer state. """ opt_0 = state.opt_state[0] - opt_0 = opt_0._replace( - mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params) - ) - opt_0 = opt_0._replace( - nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params) - ) + opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)) + opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)) new_opt = [opt_0] + list(state.opt_state[1:]) state = state.replace(opt_state=new_opt) return state diff --git a/tools/gcs_benchmarks/standalone_dataloader.py b/tools/gcs_benchmarks/standalone_dataloader.py index 9766349aac..54177e9528 100644 --- a/tools/gcs_benchmarks/standalone_dataloader.py +++ b/tools/gcs_benchmarks/standalone_dataloader.py @@ -38,13 +38,13 @@ def data_load_loop(config, state=None): """Main data loader loop. Loads batches of data for each training step. 
""" - _, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) + _, _, _, model, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) data_loader = DataLoader(config, mesh, data_iterator, None) example_batch = None start = datetime.datetime.now() - start_step = get_first_step(state) + start_step = get_first_step(model, state) example_batch = data_loader.load_next_batch() jax.block_until_ready(example_batch) first_end = datetime.datetime.now()