Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

# This conversion script reads paxml-format weights and emits a Linen-format
# MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`,
# `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen
# tree shape). Use the Linen path regardless of pure_nnx.
quant = quantizations.configure_quantization(cfg)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand All @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
Expand Down
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,13 @@ logical_axis_rules: [
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
# ==========================================
# Pipeline Parallelism
# ==========================================
['layers_outside_pipeline', []],
['layers_per_stage', []],
['num_activations', []],
['circular_repeats', []],
# ==========================================
# Deprecated / Scheduled for Removal
# ==========================================
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
  """Validate the vocab-tiling configuration.

  Args:
    num_vocab_tiling: Number of vocab tiles the loss computation is split into.
    per_device_batch_size: Per-device batch size.
    max_target_length: Maximum target sequence length.
    enable_nnx: Unused; kept for signature compatibility with callers.

  Raises:
    ValueError: If per-device batch size times sequence length is not
      divisible by the number of vocab tiles.
  """
  del enable_nnx  # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
  # NOTE: the former `num_vocab_tiling > 1 and enable_nnx` rejection has been
  # removed — referencing `enable_nnx` after `del enable_nnx` would raise a
  # NameError, and NNX vocab tiling is now supported.
  if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
    raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2902,8 +2902,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
# Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py.
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
raise ValueError(
Expand Down
37 changes: 16 additions & 21 deletions src/maxtext/experimental/rl/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,29 +542,28 @@ def setup_train_loop(
- eval_data_iterator: The iterator for the evaluation dataset (or None).
- state: The initialized training state.
"""
# GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below,
# state_mesh_shardings.params, and the inference path through MaxEngine which is
# Linen-only). Run on Linen path regardless of pure_nnx; warn the user since
# NNX-format checkpoints will mismatch at restore time.
if config.pure_nnx or config_inference.pure_nnx:
max_logging.log(
"WARNING: GRPO RL trainer does not yet support pure_nnx natively; "
"running on the Linen path. NNX-format checkpoints will not load correctly here."
)
with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
max_logging.log("Training mesh used for the workload")
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
training_devices = jax.devices()[num_inference_devices:]
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = mt.from_config(config, devices=training_devices)
model = mt.from_config(config, devices=training_devices)
mesh = model.mesh
max_logging.log("Inference mesh used for the workload")
inference_devices = jax.devices()[:num_inference_devices]
if config_inference.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_mesh = inference_model.mesh
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
Expand All @@ -573,14 +572,10 @@ def setup_train_loop(
data_iterator, config, mesh, checkpoint_manager, init_state_fn
)

# create inference_state_mesh_shardings from inference_mesh
if config_inference.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
# create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above)
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
config_inference, inference_mesh, init_inference_state_fn, is_training=False
)[2]
Expand Down
22 changes: 7 additions & 15 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None):
devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices)
self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Model and Optimizer definition
# Model and Optimizer definition.
# MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params,
# state.opt_state) and serves Linen-format inference checkpoints. Use Linen path
# regardless of pure_nnx — the flag affects training, not inference serving.
quant = quantizations.configure_quantization(config)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))

self.abstract_params = None
Expand Down Expand Up @@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
rng1, rng2, rng3 = jax.random.split(rng, 3)
if params:
print("Resharding given params")
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
self.config, self._mesh, init_state_fn, False
)
Expand All @@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
state = maxtext_utils.init_decode_state(None, params)
state = max_utils.unbox_logicallypartioned(state)
else:
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,14 @@ def __init__(
elif self.is_qwen3_hybrid:
self.query_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
)
self.key_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2250,9 +2250,9 @@ def __call__(
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)

# Only apply per expert scales if we have not fused with the out-projections at init time.
if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales:
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
# Only apply per expert scales if we have not fused with the out-projections at init time.
if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales:
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]

if self.wi_0_sparsity_module is not None:
_, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel)
Expand Down
30 changes: 27 additions & 3 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,16 @@ def pure_layer_fn(state_in, y_in):
out = merged_layer(y_in, **kwargs)
return out, nnx.state(merged_layer)

checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
# mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
# but the Linen scope retains JAX tracers from the first trace, causing
# UnexpectedTracerError. Skip checkpoint for these quantization types.
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
if uses_linen_fp8_mutable_state:
out, new_state = pure_layer_fn(state, y)
else:
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
nnx.update(layer, new_state)

return out
Expand Down Expand Up @@ -667,7 +675,23 @@ def layer_fn(carry, scanned_vars):
params = nnx_ensure_scan_leading_axis(params, length)
state = nnx_ensure_scan_leading_axis(state, length)

final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
# mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
# intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
# causing UnexpectedTracerError. Use a Python for loop instead for these types.
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
if uses_linen_fp8_mutable_state:
carry = x_in
per_layer_states = []
for i in range(length):
current_params = jax.tree.map(lambda x, i=i: x[i], params)
current_state = jax.tree.map(lambda x, i=i: x[i], state)
carry, new_state_i = layer_fn(carry, (current_params, current_state))
per_layer_states.append(new_state_i)
final_carry = carry
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
else:
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
Expand Down
35 changes: 35 additions & 0 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flax.core import FrozenDict
from flax.core import meta
from flax.nnx import graph
from flax.nnx import tracers as nnx_tracers
from flax.nnx import variablelib
from flax.nnx.bridge import module as bdg_module
from flax.nnx.module import Module
Expand Down Expand Up @@ -167,6 +168,39 @@ def current_linen_module() -> linen.Module | None:
return None


def is_linen_initializing() -> bool:
  """Report whether execution is currently inside a Linen ``init()`` call.

  True only when invoked from within a ``to_linen_class`` wrapper's
  ``init()`` path; relies on :func:`current_linen_module` (the private
  module-stack accessor already used by this file) to find the active
  Linen module.

  NNX pipeline modules use this to short-circuit the full scan during
  Linen init, where only the output shape/dtype matters.
  """
  active = current_linen_module()
  if active is None:
    return False
  probe = getattr(active, "is_initializing", None)
  if not callable(probe):
    return False
  return probe()


def _refresh_variable_trace_state(module: Module) -> None:
  """Reset stale ``_trace_state`` on Variables inside *module*.

  ``nnx.update()`` called with tracer values from a JAX transformation
  (e.g. jax.grad's LinearizeTracer) uses ``_unsafe_bypass_check=True``,
  which swaps in the raw value without refreshing ``_trace_state``. The
  Variable is then left carrying the outer (Python) context's trace
  state, and a later ``nnx.split()`` fails with "Cannot extract graph
  node from different trace level".

  Any Variable whose ``_can_update`` is False gets a fresh TraceState so
  downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed.
  """
  for _, node in nnx.graph.iter_graph(module):
    if not isinstance(node, variablelib.Variable):
      continue
    if node._can_update:  # pylint: disable=protected-access
      continue
    # Bypass Variable's own __setattr__ guards to install a fresh trace state.
    object.__setattr__(node, "_trace_state", nnx_tracers.TraceState())


class ToNNX(Module):
"""A wrapper to turn any Linen module into an NNX module.

Expand Down Expand Up @@ -476,6 +510,7 @@ def maybe_unbox(x):
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")

nnx.update(module, new_state)
_refresh_variable_trace_state(module)

_fix_for_qwix_quantization(module)
method_fn = _get_module_method(module, nnx_method)
Expand Down
14 changes: 12 additions & 2 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
return y_flat.reshape(input_shape)


def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
def Qwen3NextRMSNorm(
num_features: int,
epsilon: float = 1e-6,
dtype: DType = None,
weight_dtype: DType = None,
shard_mode=None,
kernel_axes=None,
parameter_memory_host_offload=None,
*,
rngs: nnx.Rngs,
):
"""
Used for input and post attention layernorms
in Qwen3NextDecoderLayer.
Expand All @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
return nnx.data(
RMSNorm(
num_features=num_features,
epsilon=eps,
epsilon=epsilon,
dtype=dtype,
weight_dtype=weight_dtype,
scale_init=linen_initializers.zeros,
Expand Down
24 changes: 19 additions & 5 deletions src/maxtext/layers/train_state_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" The NNX Unified TrainState. """
"""The NNX Unified TrainState."""

from typing import Any

Expand All @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module):
This replaces Linen's TrainState for checkpointing.

Linen TrainState pytree:
{params: {...}, opt_state: {}...}
{"params": {...}, "opt_state": {}...}
TrainStateNNX state pytree:
{“model”: {...}, “optimizer”: {“opt_state”: {...}}
{"model": {...}, "optimizer": {"opt_state": {...}}}

For DPO (Direct Preference Optimization), an optional `reference_model`
carries a frozen copy of the same architecture used to compute reference
log-probabilities. Only `model` is updated by `apply_gradients`; the
reference is held alongside so it is sharded, jit-traced, and checkpointed
with the rest of the train state.
"""

def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None):
def __init__(
    self,
    model: nnx.Module,
    optimizer: nnx.Optimizer | None,
    reference_model: nnx.Module | None = None,
):
  """Build the unified NNX train state.

  Args:
    model: The trainable model; the only module touched by `apply_gradients`.
    optimizer: Optimizer wrapping the model's parameters, or None.
    reference_model: Optional frozen copy of the same architecture used for
      DPO reference log-probabilities; held alongside so it is sharded,
      jit-traced, and checkpointed with the rest of the state.
  """
  self.model = model
  self.optimizer = optimizer
  # Only attach the attribute when a reference model is supplied —
  # presumably so downstream code can detect DPO mode via attribute
  # presence; verify against callers before changing to an unconditional
  # `self.reference_model = reference_model`.
  if reference_model is not None:
    self.reference_model = reference_model

def apply_gradients(self, grads: Any):
"""
Mimics the Linen apply_gradients function.
Updates the optimizer state, applies updates to parameters,
and increments the step counter.
and increments the step counter. Only updates `self.model`;
`self.reference_model` (if present) is left untouched.
"""
if self.optimizer is None:
raise RuntimeError(
Expand Down
Loading
Loading