Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

# Output is Linen-format (keystr_map below uses Linen tree paths). Route to
# Linen regardless of pure_nnx.
quant = quantizations.configure_quantization(cfg)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand All @@ -102,11 +101,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
  """Checks that the per-device token count splits evenly across vocab tiles.

  Args:
    num_vocab_tiling: Number of vocab tiles the per-device tokens are split into.
    per_device_batch_size: Batch size on each device.
    max_target_length: Sequence length.
    enable_nnx: Ignored; kept so existing call sites keep working.

  Raises:
    ValueError: If ``per_device_batch_size * max_target_length`` is not
      divisible by ``num_vocab_tiling``.
  """
  # The flag no longer gates anything: the former "vocab tiling on NNX is
  # unsupported" rejection was removed, so the name is dropped up front and
  # must not be read again below.
  del enable_nnx  # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
  if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
    raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2959,8 +2959,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
raise ValueError(
Expand Down
35 changes: 14 additions & 21 deletions src/maxtext/experimental/rl/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,29 +537,26 @@ def setup_train_loop(
- eval_data_iterator: The iterator for the evaluation dataset (or None).
- state: The initialized training state.
"""
# GRPO is Linen-shaped end-to-end (inference goes through Linen MaxEngine).
# Route to Linen regardless of pure_nnx; warn since NNX checkpoints won't load.
if config.pure_nnx or config_inference.pure_nnx:
max_logging.log(
"WARNING: GRPO RL trainer does not yet support pure_nnx natively; "
"running on the Linen path. NNX-format checkpoints will not load correctly here."
)
with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
max_logging.log("Training mesh used for the workload")
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
training_devices = jax.devices()[num_inference_devices:]
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = mt.from_config(config, devices=training_devices)
model = mt.from_config(config, devices=training_devices)
mesh = model.mesh
max_logging.log("Inference mesh used for the workload")
inference_devices = jax.devices()[:num_inference_devices]
if config_inference.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_mesh = inference_model.mesh
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
Expand All @@ -568,14 +565,10 @@ def setup_train_loop(
data_iterator, config, mesh, checkpoint_manager, init_state_fn
)

# create inference_state_mesh_shardings from inference_mesh
if config_inference.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
# create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above)
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
config_inference, inference_mesh, init_inference_state_fn, is_training=False
)[2]
Expand Down
20 changes: 5 additions & 15 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,10 @@ def __init__(self, config: Any, devices: Any | None = None):
devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices)
self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Model and Optimizer definition
# MaxEngine serves Linen-format inference checkpoints; the surface stays
# Linen-shaped via transformer_as_linen regardless of pure_nnx.
quant = quantizations.configure_quantization(config)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))

self.abstract_params = None
Expand Down Expand Up @@ -232,11 +230,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
rng1, rng2, rng3 = jax.random.split(rng, 3)
if params:
print("Resharding given params")
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
self.config, self._mesh, init_state_fn, False
)
Expand All @@ -245,11 +239,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
state = maxtext_utils.init_decode_state(None, params)
state = max_utils.unbox_logicallypartioned(state)
else:
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
Expand Down
27 changes: 24 additions & 3 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,14 @@ def pure_layer_fn(state_in, y_in):
out = merged_layer(y_in, **kwargs)
return out, nnx.state(merged_layer)

checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
if uses_linen_fp8_mutable_state:
out, new_state = pure_layer_fn(state, y)
else:
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
nnx.update(layer, new_state)

return out
Expand Down Expand Up @@ -667,7 +673,22 @@ def layer_fn(carry, scanned_vars):
params = nnx_ensure_scan_leading_axis(params, length)
state = nnx_ensure_scan_leading_axis(state, length)

final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
# for FP8 instead.
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
if uses_linen_fp8_mutable_state:
carry = x_in
per_layer_states = []
for i in range(length):
current_params = jax.tree.map(lambda x, i=i: x[i], params)
current_state = jax.tree.map(lambda x, i=i: x[i], state)
carry, new_state_i = layer_fn(carry, (current_params, current_state))
per_layer_states.append(new_state_i)
final_carry = carry
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
else:
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
Expand Down
27 changes: 27 additions & 0 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flax.core import FrozenDict
from flax.core import meta
from flax.nnx import graph
from flax.nnx import tracers as nnx_tracers
from flax.nnx import variablelib
from flax.nnx.bridge import module as bdg_module
from flax.nnx.module import Module
Expand Down Expand Up @@ -167,6 +168,31 @@ def current_linen_module() -> linen.Module | None:
return None


def is_linen_initializing() -> bool:
  """Reports whether execution is currently inside a Linen ``init()`` call.

  NNX pipeline modules use this to short-circuit the scan during init,
  where only the output shape/dtype is needed.

  Returns:
    The result of the current Linen module's ``is_initializing()``, or False
    when there is no current Linen module or it exposes no callable
    ``is_initializing``.
  """
  active_module = current_linen_module()
  if active_module is None:
    return False
  # A missing or non-callable attribute means "not initializing".
  probe = getattr(active_module, "is_initializing", None)
  if not callable(probe):
    return False
  return probe()


def _refresh_variable_trace_state(module: Module) -> None:
  """Replaces stale ``_trace_state`` on Variables so a later ``nnx.split`` works.

  When ``nnx.update`` is called with JAX tracer values it uses
  ``_unsafe_bypass_check=True``, leaving Variables with a ``_trace_state``
  from the outer Python context. A subsequent ``nnx.split`` then fails with
  "Cannot extract graph node from different trace level". Every Variable in
  the module graph whose ``_can_update`` is False gets a fresh ``TraceState``.

  Args:
    module: The NNX module whose graph is walked.
  """
  for _path, node in nnx.graph.iter_graph(module):
    if not isinstance(node, variablelib.Variable):
      continue
    if node._can_update:  # pylint: disable=protected-access
      continue
    # Assign via object.__setattr__, as the original does, rather than through
    # the Variable's own attribute setter.
    object.__setattr__(node, "_trace_state", nnx_tracers.TraceState())


class ToNNX(Module):
"""A wrapper to turn any Linen module into an NNX module.

Expand Down Expand Up @@ -476,6 +502,7 @@ def maybe_unbox(x):
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")

nnx.update(module, new_state)
_refresh_variable_trace_state(module)

_fix_for_qwix_quantization(module)
method_fn = _get_module_method(module, nnx_method)
Expand Down
12 changes: 11 additions & 1 deletion src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
return y_flat.reshape(input_shape)


def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
def Qwen3NextRMSNorm(
num_features: int,
eps: float = 1e-6,
dtype: DType = None,
weight_dtype: DType = None,
shard_mode=None,
kernel_axes=None,
parameter_memory_host_offload=None,
*,
rngs: nnx.Rngs,
):
"""
Used for input and post attention layernorms
in Qwen3NextDecoderLayer.
Expand Down
5 changes: 4 additions & 1 deletion src/maxtext/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from maxtext.common.common_types import AttentionType, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import linears
from maxtext.layers import moe
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
Expand Down Expand Up @@ -132,6 +133,8 @@ def __init__(
rngs=rngs,
)

self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)

def __call__(
self,
inputs,
Expand Down Expand Up @@ -189,7 +192,7 @@ def __call__(
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

layer_output = mlp_lnx + intermediate_inputs
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
layer_output = self.dropout(layer_output, deterministic=deterministic)

layer_output = nn.with_logical_constraint(
layer_output,
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
shard_mode=config.shard_mode,
kernel_axes=("norm",),
epsilon=config.normalization_layer_epsilon,
parameter_memory_host_offload=config.parameter_memory_host_offload,
rngs=rngs,
)

Expand Down
11 changes: 10 additions & 1 deletion src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def setup(self):
rngs=self.make_rng("mtp_block"),
)

def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
def logits_from_hidden_states_for_vocab_tiling(self, hidden_states, deterministic, model_mode):
"""
Compute logits from hidden states (wrapping decoder.apply_output_head).
This function is only used for vocabulary tiling.
Expand Down Expand Up @@ -398,6 +398,15 @@ def no_op(self, *args, **kwargs):
"""A no-op method to allow the model to be used in a lazy context."""
return

def logits_from_hidden_states_for_vocab_tiling(self, hidden_states, deterministic, model_mode):
  """Turns hidden states into logits via the decoder's output head.

  This entry point is used by vocabulary tiling.

  Args:
    hidden_states: Passed to the output head as ``y``.
    deterministic: Forwarded unchanged to the output head.
    model_mode: Forwarded unchanged to the output head.

  Returns:
    Whatever ``self.decoder.apply_output_head`` returns.
  """
  output_head = self.decoder.apply_output_head
  logits = output_head(
      shared_embedding=self.token_embedder,
      y=hidden_states,
      deterministic=deterministic,
      model_mode=model_mode,
  )
  return logits

def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
"""Initializes the KV cache for the Transformer.

Expand Down
4 changes: 3 additions & 1 deletion src/maxtext/models/olmo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from maxtext.common.common_types import AttentionType, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import linears
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
model_mode=model_mode,
rngs=rngs,
)
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)

def __call__(
self,
Expand Down Expand Up @@ -202,7 +204,7 @@ def __call__(
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

layer_output = mlp_lnx + intermediate_inputs
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
layer_output = self.dropout(layer_output, deterministic=deterministic)

layer_output = nn.with_logical_constraint(
layer_output,
Expand Down
Loading
Loading