Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"""

import argparse
import functools
import gc
import os
import sys
Expand Down Expand Up @@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
mesh = Mesh(devices_array, cfg.mesh_axes)

quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand All @@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")

Expand Down
5 changes: 3 additions & 2 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1087,8 +1087,9 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
enable_nnx: False
pure_nnx_decoder: False
pure_nnx: False

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ class HardwareAndMesh(BaseModel):
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")


class LayoutAndSharding(BaseModel):
Expand Down
32 changes: 26 additions & 6 deletions src/maxtext/experimental/rl/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,23 +546,43 @@ def setup_train_loop(
max_logging.log("Training mesh used for the workload")
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
training_devices = jax.devices()[num_inference_devices:]
model = mt.from_config(config, devices=training_devices)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = mt.from_config(config, devices=training_devices)
mesh = model.mesh
max_logging.log("Inference mesh used for the workload")
inference_devices = jax.devices()[:num_inference_devices]
inference_model = mt.from_config(config_inference, devices=inference_devices)
if config_inference.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_mesh = inference_model.mesh
init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh)
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh)
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
data_iterator, config, mesh, checkpoint_manager, init_state_fn
)

# create inference_state_mesh_shardings from inference_mesh
if config_inference.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False
config_inference, inference_mesh, init_inference_state_fn, is_training=False
)[2]
if not config.using_pipeline_parallelism:
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
Expand Down Expand Up @@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None):
data_buffer = []
data_buffer_lock = threading.Lock()

start_step = get_first_step(state) # this is the start_step for training
start_step = get_first_step(model, state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
Expand Down
21 changes: 16 additions & 5 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None):

# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))

self.abstract_params = None
Expand Down Expand Up @@ -229,17 +232,25 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
rng1, rng2, rng3 = jax.random.split(rng, 3)
if params:
print("Resharding given params")
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
self.model, None, self.config, rng, self._mesh, False
self.config, self._mesh, init_state_fn, False
)
# reshard given params based on shardings from config in MaxEngine
params = jax.device_put(params, state_mesh_shardings.params)
state = maxtext_utils.init_decode_state(None, params)
state = max_utils.unbox_logicallypartioned(state)
else:
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(
self.model, self.config, rng1, self._mesh, None
)
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
Expand Down
48 changes: 48 additions & 0 deletions src/maxtext/layers/train_state_nnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" The NNX Unified TrainState. """

from typing import Any

from flax import nnx


class TrainStateNNX(nnx.Module):
  """A unified container for NNX models and optimizers.

  This replaces Linen's TrainState for checkpointing.

  Linen TrainState pytree:
    {"params": {...}, "opt_state": {}...}
  TrainStateNNX state pytree:
    {"model": {...}, "optimizer": {"opt_state": {...}}
  """

  def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None):
    # optimizer may be None for inference-only states; apply_gradients
    # will raise in that case rather than silently no-op.
    self.model = model
    self.optimizer = optimizer

  def apply_gradients(self, grads: Any) -> None:
    """Mimics the Linen apply_gradients function.

    Updates the optimizer state, applies updates to parameters,
    and increments the step counter.

    Args:
      grads: Gradient pytree matching the model's trainable parameters.

    Raises:
      RuntimeError: If this state was constructed without an optimizer
        (i.e. it was created for inference only).
    """
    if self.optimizer is None:
      raise RuntimeError(
          "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. "
          "This usually happens when the state was created for inference only."
      )
    # NOTE(review): assumes the newer nnx.Optimizer.update(model, grads)
    # signature (flax >= 0.11) — confirm against the pinned flax version.
    self.optimizer.update(self.model, grads)
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def train_loop(config, recorder, state=None):
compiled_stats = compiled.memory_analysis()
max_utils.print_compiled_memory_stats(compiled_stats)

start_step = get_first_step(state) # this is the start_step for training
start_step = get_first_step(model, state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
data_loader = DataLoader(config, mesh, data_iterator, recorder)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
Expand Down
8 changes: 5 additions & 3 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@
VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules()


def get_first_step(state):
return int(state.step)
def get_first_step(model, state):
  """Return the step number recorded in the training state.

  For a Linen model the step counter is stored directly on the state;
  in the NNX path it lives on the state's optimizer.
  """
  if isinstance(model, nn.Module):
    step = state.step
  else:
    step = state.optimizer.step.get_value()
  return int(step)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -512,7 +514,7 @@ def train_loop(config, recorder, state=None):
compiled_stats = compiled.memory_analysis()
max_utils.print_compiled_memory_stats(compiled_stats)

start_step = get_first_step(state) # this is the start_step for training
start_step = get_first_step(model, state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

Expand Down
64 changes: 50 additions & 14 deletions src/maxtext/trainers/pre_train/train_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Sequence

from absl import app
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
from jax.experimental.serialize_executable import serialize
Expand All @@ -36,6 +37,7 @@
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
from maxtext.layers import quantizations
from maxtext.layers import train_state_nnx
from maxtext.models import models
from maxtext.optimizers import optimizers
from maxtext.trainers.diloco import diloco
Expand All @@ -44,6 +46,8 @@
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import sharding
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils

# pylint: disable=too-many-positional-arguments

Expand Down Expand Up @@ -93,7 +97,10 @@ def get_shaped_inputs(topology_mesh, config):
"""Get shaped abstractions of inputs to train_step: state, batch and rng"""
# Construct the model and optimizer to get shaped versions of the state
quant = quantizations.configure_quantization(config)
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
if config.pure_nnx:
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
else:
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
# The learning_rate_schedule is baked into the compiled object.
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
# pass in model for muon
Expand All @@ -103,18 +110,39 @@ def get_shaped_inputs(topology_mesh, config):
_, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)

# Shaped state
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
model, tx, config, example_rng, topology_mesh
)
if config.pure_nnx:

def create_train_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

init_state_fn = create_train_state_fn
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng)

# unsharded logical annotations
logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
# Shaped state
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True)

if config.pure_nnx:
# NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
# For NNX, get_functional_train_with_signature expects the graphdef (static structure),
# not the raw model — mirroring how the training loop does nnx.split(train_state).
with nn_partitioning.axis_rules(config.logical_axis_rules):
graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh)
model = graphdef
else:
# unsharded logical annotations
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)

# Shaped batch
shaped_batch = maxtext_utils.get_shaped_batch(config)

shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
if config.pure_nnx:
shaped_train_args = (abstract_state, shaped_batch)
else:
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
shaped_train_kwargs = {}
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model

Expand Down Expand Up @@ -277,12 +305,20 @@ def main(argv: Sequence[str]) -> None:
# print weights sharding info under debug sharding mode
if config.debug_sharding:
max_utils.print_non_trivial_mesh_axis(topology_mesh)
maxtext_utils.print_shardings_params(
shaped_train_args[0].params,
state_mesh_shardings.params,
topology_mesh,
logical_annotations.params,
)
if config.pure_nnx:
maxtext_utils.print_shardings_params(
shaped_train_args[0],
state_mesh_shardings,
topology_mesh,
logical_annotations,
)
else:
maxtext_utils.print_shardings_params(
shaped_train_args[0].params,
state_mesh_shardings.params,
topology_mesh,
logical_annotations.params,
)

# Compile
print("Jitting and compiling train step...", flush=True)
Expand Down
20 changes: 15 additions & 5 deletions src/maxtext/utils/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
"""

import functools
import os.path
from typing import Sequence

Expand All @@ -42,8 +43,6 @@
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils

Transformer = models.transformer_as_linen


def _possibly_unroll_params(config, training_state, training_state_annotations, mesh):
"""Unroll scanned input layers when force_unroll is set."""
Expand Down Expand Up @@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
"""Read training checkpoint at path defined by load_full_state_path."""
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
rng = random.PRNGKey(0)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng)
state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
model, None, tx, config, rng, mesh, checkpoint_manager
None, config, mesh, checkpoint_manager, init_state_fn
)
num_params = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion")
Expand All @@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh):
"""Read lora checkpoints checkpoint at path defined by load_full_state_path."""
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
rng = random.PRNGKey(0)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
Expand Down
Loading
Loading