From 7955a4466bfd0f67e88025052bfe66505012b2fb Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Mon, 16 Mar 2026 03:03:57 +0000 Subject: [PATCH 1/5] Test --- src/maxtext/configs/base.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 457e90606d..c82f00a5e6 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1110,7 +1110,7 @@ position_id_per_seconds: 25 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium. subslice_shape: "" -# NNX +# NNX tests enable_nnx: false pure_nnx_decoder: false From 4eae7868ad2bda642f9f53fd25fd31b46f6bea99 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Wed, 25 Mar 2026 12:19:13 +0800 Subject: [PATCH 2/5] fix --- src/maxtext/configs/base.yml | 6 +- src/maxtext/layers/attentions.py | 4 +- src/maxtext/layers/initializers.py | 13 + src/maxtext/layers/nnx_decoders.py | 235 +++++--- src/maxtext/layers/normalizations.py | 17 +- src/maxtext/models/qwen3.py | 4 +- ...{max_utils_test.py => a_max_utils_test.py} | 0 tests/unit/nnx_decoder_test.py | 536 ++++++++++++++++++ tests/unit/train_compile_test.py | 4 +- 9 files changed, 737 insertions(+), 82 deletions(-) rename tests/unit/{max_utils_test.py => a_max_utils_test.py} (100%) create mode 100644 tests/unit/nnx_decoder_test.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index c82f00a5e6..9a290c805b 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1110,9 +1110,9 @@ position_id_per_seconds: 25 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium. subslice_shape: "" -# NNX tests -enable_nnx: false -pure_nnx_decoder: false +# NNX +enable_nnx: True +pure_nnx_decoder: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 813cb33014..824b7590eb 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -534,14 +534,14 @@ def __init__( elif self.is_qwen3_next: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py index 20baf9a633..a38a1d5efe 100644 --- a/src/maxtext/layers/initializers.py +++ b/src/maxtext/layers/initializers.py @@ -94,6 +94,19 @@ def variable_to_logically_partitioned(variable: nnx.VariableState): out_sharding = metadata["sharding"] if out_sharding is not None: + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Only nnx.Param variables are typically scanned across the param_scan_axis + scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0 + + # Check if the scan axis name was already inserted into out_sharding + # (e.g., by _create_scanned_layers). Skip insertion to avoid duplicates. + sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + if partition_name not in sharding_list: + sharding_list.insert(scan_axis, partition_name) + + out_sharding = tuple(sharding_list) + return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args] variable.value, out_sharding, # type: ignore[arg-type] diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..647d6611a6 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -307,11 +307,10 @@ def __init__( dense_cls, moe_cls = decoder_block_classes num_dense = config.first_num_dense_layers - self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) - + self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, metadata_axis_name="dense_layers", rngs=rngs) num_moe = config.num_decoder_layers - config.first_num_dense_layers - - self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) + self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs) + elif self.is_gemma3: attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) scan_length = config.num_decoder_layers // attention_pattern_length @@ -323,7 +322,9 @@ def __init__( RemattedGemma3Block = gemma3.Gemma3ScannableBlock if scan_length > 0: - self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) + self.layers = self._create_scanned_layers( + RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) self.layers_remainder = RemattedGemma3Block( config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs ) # pytype: disable=wrong-keyword-args @@ -337,7 +338,13 @@ def __init__( "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) + if num_layers > 0: + self.layers = self._create_scanned_layers( + layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + else: + self.layers = nnx.List([]) + else: self.layers = nnx.List([]) @@ -386,34 +393,80 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): ) return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) - def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): - """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" + def _create_scanned_layers(self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs): + """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization. - def create_layer_fn(rng): - layer = decoder_layer_class( - config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs - ) - - return layer + Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization. + With vmap, all layers' parameters are created simultaneously (O(N) peak memory). + With scan, parameters are created one layer at a time (O(1) peak intermediate memory), + which prevents OOM on memory-constrained devices like TPU v6e-4. + """ + scan_axis = self.config.param_scan_axis - # Workaround for Deepseek MTP test failure. - # TODO: Handle this properly. + # Fork rngs to get per-layer RNG states for scanning try: forked_rngs = rngs.fork(split=length) - except: # pylint: disable=bare-except pass - out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) - layers_vmapped = nnx.vmap( - create_layer_fn, - in_axes=0, - out_axes=out_axes, - axis_name="layers", - transform_metadata={nnx.PARTITION_NAME: "layers"}, - )(forked_rngs) + rngs_graphdef, rngs_state = nnx.split(forked_rngs) + + # Create a reference layer to capture the module graph structure (graphdef). + # This layer's params are discarded — only the structure is kept. + # Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the + # graphdef has the same number of RNG state leaves as the scan-created layers. + first_rng_state = jax.tree.map(lambda x: x[0], rngs_state) + ref_rngs = nnx.merge(rngs_graphdef, first_rng_state) + ref_layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, + model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs + ) + layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...) + del ref_layer + + # Sequentially create each layer's parameters via jax.lax.scan. + # The scan body is traced once; XLA executes it N times with different RNG keys, + # keeping only one layer's intermediate state alive at a time. + def scan_body(carry, rng_state_slice): + layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice) + layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, + model_mode=self.model_mode, rngs=layer_rngs, **layer_kwargs + ) + _, params, rest = nnx.split(layer, nnx.Param, ...) + return carry, (params, rest) + + _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state) - return layers_vmapped + # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis. + if scan_axis != 0: + stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params) + + # Add partition metadata that nnx.vmap's transform_metadata would normally set. + # This metadata is read by variable_to_logically_partitioned() in initializers.py + # and by nnx.get_partition_spec() (via the updated out_sharding) to produce + # correct sharding specs that include the scan axis dimension. + def _add_scan_metadata(state, axis): + def _update_leaf(leaf): + if isinstance(leaf, nnx.VariableState): + metadata = leaf.get_metadata() + metadata[nnx.PARTITION_NAME] = metadata_axis_name + metadata["param_scan_axis"] = axis + # Insert the scan axis name into out_sharding so that + # nnx.get_partition_spec returns specs matching the actual tensor rank. + # Without this, scanned params are 3D but specs remain 2D. + if "out_sharding" in metadata and metadata["out_sharding"]: + sharding = list(metadata["out_sharding"]) + sharding.insert(axis, metadata_axis_name) + metadata["out_sharding"] = tuple(sharding) + return leaf.replace(**metadata) + return leaf + return jax.tree.map(_update_leaf, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + + stacked_params = _add_scan_metadata(stacked_params, scan_axis) + stacked_rest = _add_scan_metadata(stacked_rest, 0) + + return nnx.merge(layer_graphdef, stacked_params, stacked_rest) def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" @@ -435,53 +488,47 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs) """Runs the layer stack using nnx.scan.""" policy = self.get_remat_policy() prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) - graphdef, params, state = nnx.split( - layers, nnx.Param, ... - ) # state: the mutable state we carry (KV cache, RNGs, etc.) + graphdef, params, state = nnx.split(layers, nnx.Param, ...) scan_axis = self.config.param_scan_axis if scan_axis != 0: - # Move scan_axis to 0 so scan can iterate over it params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) layer_cls = layers.__class__ sig = inspect.signature(layer_cls.__call__) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - layer_cls = layers.__class__ # Access the underlying class - sig = inspect.signature(layer_cls.__call__) - # Filter kwargs to only include keys that exist in the layer's signature - valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + def _extract_matching_state(template, full): + if isinstance(template, nnx.State): + return nnx.State({k: _extract_matching_state(v, full[k]) for k, v in template.items()}) + elif isinstance(template, dict): + return {k: _extract_matching_state(v, full[k]) for k, v in template.items()} + return full def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer current_params, current_state = scanned_vars if self.config.parameter_memory_host_offload: current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params) - # Merge using the SLICED state layer = nnx.merge(graphdef, current_params, current_state) - - # Run the layer (Filter kwargs if using the solution from previous turn) layer_out = layer(carry, *args, **valid_kwargs) - new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out - - # Extract the updated state to return it - # _, new_current_state = nnx.split(layer, nnx.Param, ...) - new_current_state = nnx.state(layer) + + new_full_state = nnx.state(layer) + new_current_state = _extract_matching_state(current_state, new_full_state) + + # ONLY return non-param state to prevent memory duplication of weights return new_carry, new_current_state layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state)) if scan_axis != 0: - scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) - scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) - scanned_state = nnx.State.merge(scanned_params, scanned_other) - + params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) + + scanned_state = nnx.State.merge(params, scanned_other) return final_carry, nnx.merge(graphdef, scanned_state) def get_decoder_layers(self): @@ -829,10 +876,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices): def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs): """Applies a single, unscanned Engram layer by dynamically slicing the NNX state.""" graphdef, state = nnx.split(layer_stack) + params, rest = state.split(nnx.Param, ...) + scan_axis = self.config.param_scan_axis + + # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :]) + def _extract_slice(x, idx, axis): + slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim)) + return x[slices] - # Slice the parameters for the current index (assuming scan axis is 0) - sliced_state = jax.tree.map(lambda x: x[current_idx], state) - single_layer = nnx.merge(graphdef, sliced_state) + # Slice using native indexing instead of jnp.take + sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params) + sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest) + + single_layer = nnx.merge(graphdef, sliced_params, sliced_rest) # Run the single layer out = single_layer( @@ -841,14 +897,23 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg y = out[0] if isinstance(out, tuple) else out # Re-merge the updated state back into the specific slice of the stack - new_single_state = nnx.state(single_layer) - updated_state = jax.tree.map( + new_state = nnx.state(single_layer) + new_params, new_rest = new_state.split(nnx.Param, ...) + + updated_params = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim( + s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis + ), + params, + new_params, + ) + updated_rest = jax.tree.map( lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0), - state, - new_single_state, + rest, + new_rest, ) - nnx.update(layer_stack, updated_state) + nnx.update(layer_stack, updated_params, updated_rest) return y def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs): @@ -856,10 +921,15 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args scan_length = next_boundary - current_idx if scan_length > 0: graphdef, state = nnx.split(layer_stack) + params, rest = state.split(nnx.Param, ...) + scan_axis = self.config.param_scan_axis - # Slice the chunk state - chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state) - chunk_stack = nnx.merge(graphdef, chunk_state) + # Slice the chunk state along the correct axes + chunk_params = jax.tree.map( + lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params + ) + chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest) + chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest) # Apply sequentially y, chunk_stack = self._apply_layers_sequentially( @@ -867,11 +937,17 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args ) # Update the original stack state - new_chunk_state = nnx.state(chunk_stack) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state + new_state = nnx.state(chunk_stack) + new_params, new_rest = new_state.split(nnx.Param, ...) + + updated_params = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params ) - nnx.update(layer_stack, updated_state) + updated_rest = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest + ) + + nnx.update(layer_stack, updated_params, updated_rest) return y @@ -961,7 +1037,7 @@ def __call__( y = self._apply_interleaved_scanned_layers( y, - self.moe_layer, + self.moe_layers, 0, (cfg.num_decoder_layers - cfg.first_num_dense_layers), [e - cfg.first_num_dense_layers for e in cfg.engram_layers], @@ -978,7 +1054,7 @@ def __call__( if cfg.use_batch_split_schedule: policy = self.get_remat_policy() - mock_params = self._build_linen_params(self.moe_layer) + mock_params = self._build_linen_params(self.moe_layers) y = deepseek_batchsplit.scan_batch_split_layers( y, @@ -992,8 +1068,8 @@ def __call__( policy=policy, ) else: - y, self.moe_layer = self._apply_layers_sequentially( - self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs + y, self.moe_layers = self._apply_layers_sequentially( + self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs ) elif self.is_gemma3: y = self._apply_gemma3_scanned_blocks( @@ -1009,7 +1085,8 @@ def __call__( ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) else: prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) @@ -1027,7 +1104,16 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): for lyr, layer in enumerate(self.layers): graphdef, state = nnx.split(layer) - kv_cache = kv_caches[lyr] if kv_caches is not None else None + if kv_caches is not None: + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) + else: + kv_cache = None + else: + kv_cache = kv_caches[lyr] + else: + kv_cache = None input_tokens = decoder_input_tokens if cfg.engram_layers else None if input_tokens is not None: @@ -1037,7 +1123,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][lyr] = kv_cache[0] + kv_caches["value_cache"][lyr] = kv_cache[1] + else: + kv_caches[lyr] = kv_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] @@ -1059,7 +1150,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) @@ -1124,7 +1215,7 @@ def decoder_as_linen( model_mode: str, quant: None | Quant = None, ): - """Creates a Decoder module.""" + """Creates a Decoder module""" module = nnx_wrappers.to_linen( NNXDecoder, config=config, diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 195d5bcc14..be6f56c8a4 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float, + dtype: DType, + weight_dtype: DType, + shard_mode: ShardMode = ShardMode.AUTO, + kernel_axes: tuple[None | str, ...] = (), + parameter_memory_host_offload: bool = False, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, + shard_mode=shard_mode, + kernel_axes=kernel_axes, scale_init=linen_initializers.zeros, + parameter_memory_host_offload=parameter_memory_host_offload, scale_offset=1.0, rngs=rngs, ) diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index eb15747fc2..5ba630adc3 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -962,7 +962,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -987,7 +987,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/tests/unit/max_utils_test.py b/tests/unit/a_max_utils_test.py similarity index 100% rename from tests/unit/max_utils_test.py rename to tests/unit/a_max_utils_test.py diff --git a/tests/unit/nnx_decoder_test.py b/tests/unit/nnx_decoder_test.py new file mode 100644 index 0000000000..bd507e720f --- /dev/null +++ b/tests/unit/nnx_decoder_test.py @@ -0,0 +1,536 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for nnx_decoders module. + +Tests cover: + - deepstack_process: pure-JAX helper for injecting visual embeddings + - NNXDecoderLayer: single transformer decoder layer (init + forward) + - NNXDecoder: decoder stack utilities (get_decoder_layers, get_norm_layer, + get_remat_policy, minimal_policy, and full forward pass) +""" + +import sys +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn +from flax import nnx +from jax.sharding import Mesh + +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, DecoderBlockType +from maxtext.configs import pyconfig +from maxtext.layers import linears +from maxtext.layers.attentions import Attention +from maxtext.layers.embeddings import Embed +from maxtext.layers.nnx_decoders import NNXDecoder, NNXDecoderLayer, deepstack_process +from maxtext.layers.normalizations import RMSNorm +from maxtext.models.gpt3 import Gpt3LayerNorm +from maxtext.models.llama2 import LlamaDecoderLayer +from maxtext.utils import maxtext_utils +from tests.utils.test_helpers import get_decoupled_parallelism_overrides, get_test_config_path + +# --------------------------------------------------------------------------- +# Shared minimal config overrides used across most tests +# --------------------------------------------------------------------------- +_BASE_CONFIG = { + "per_device_batch_size": 1.0, + "run_name": "nnx_decoder_test", + "enable_checkpointing": False, + "base_num_decoder_layers": 2, + "attention": "dot_product", + "max_target_length": 16, + "base_emb_dim": 256, + "base_num_query_heads": 2, + "base_num_kv_heads": 2, + "base_mlp_dim": 512, + "max_prefill_predict_length": 4, + "scan_layers": False, +} + + +def _make_config(**overrides): + """Return a pyconfig Config object suitable for unit tests.""" + extra_args = get_decoupled_parallelism_overrides() + return pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **_BASE_CONFIG, + **extra_args, + **overrides, + ) + + +def _make_mesh(cfg): + devices_array = maxtext_utils.create_device_mesh(cfg) + return Mesh(devices_array, cfg.mesh_axes) + + +# --------------------------------------------------------------------------- +# 1. deepstack_process +# --------------------------------------------------------------------------- + + +class TestDeepstackProcess(unittest.TestCase): + """Tests for the deepstack_process pure function.""" + + def _make_inputs(self, batch=2, seq_len=8, hidden_dim=16, num_visual=3, seed=0): + key = jax.random.PRNGKey(seed) + k1, k2 = jax.random.split(key) + hidden_states = jax.random.normal(k1, (batch, seq_len, hidden_dim)) + mask = jnp.zeros((batch, seq_len), dtype=bool).at[:, :num_visual].set(True) + visual_embeds = jax.random.normal(k2, (batch, num_visual, hidden_dim)) + return hidden_states, mask, visual_embeds + + def test_output_shape_matches_hidden_states(self): + """Output shape must equal input hidden_states shape.""" + hidden_states, mask, visual_embeds = self._make_inputs() + result = deepstack_process(hidden_states, mask, visual_embeds) + self.assertEqual(result.shape, hidden_states.shape) + + def test_unmasked_positions_are_unchanged(self): + """Positions outside the bidirectional mask must not be modified.""" + batch, seq_len, hidden_dim, num_visual = 1, 6, 8, 2 + hidden_states = jnp.ones((batch, seq_len, hidden_dim)) + mask = jnp.zeros((batch, seq_len), dtype=bool).at[:, :num_visual].set(True) + # Zero visual embeds ensure any addition at mask=True positions is a no-op + visual_embeds = jnp.zeros((batch, num_visual, hidden_dim)) + + result = deepstack_process(hidden_states, mask, visual_embeds) + + np.testing.assert_allclose( + np.array(result[:, num_visual:, :]), + np.ones((batch, seq_len - num_visual, hidden_dim)), + ) + + def test_masked_positions_receive_visual_embeds(self): + """Visual embeddings must be added at masked (True) positions.""" + batch, seq_len, hidden_dim, num_visual = 1, 4, 4, 2 + hidden_states = jnp.zeros((batch, seq_len, hidden_dim)) + mask = jnp.zeros((batch, seq_len), dtype=bool).at[:, :num_visual].set(True) + visual_embeds = jnp.ones((batch, num_visual, hidden_dim)) + + result = deepstack_process(hidden_states, mask, visual_embeds) + + # At masked positions: 0 + 1 = 1 + np.testing.assert_allclose( + np.array(result[:, :num_visual, :]), + np.ones((batch, num_visual, hidden_dim)), + ) + # At unmasked positions: unchanged (still 0) + np.testing.assert_allclose( + np.array(result[:, num_visual:, :]), + np.zeros((batch, seq_len - num_visual, hidden_dim)), + ) + + def test_zero_visual_embeds_leave_hidden_states_unchanged(self): + """When all visual embeddings are zero, output equals input.""" + hidden_states, mask, _ = self._make_inputs() + num_visual = 3 + batch = hidden_states.shape[0] + hidden_dim = hidden_states.shape[2] + zero_visual = jnp.zeros((batch, num_visual, hidden_dim)) + + result = deepstack_process(hidden_states, mask, zero_visual) + + np.testing.assert_allclose(np.array(result), np.array(hidden_states)) + + def test_all_positions_masked(self): + """Works correctly when every token position is a visual token.""" + batch, seq_len, hidden_dim = 1, 4, 8 + hidden_states = jnp.zeros((batch, seq_len, hidden_dim)) + mask = jnp.ones((batch, seq_len), dtype=bool) + visual_embeds = jnp.ones((batch, seq_len, hidden_dim)) * 2.0 + + result = deepstack_process(hidden_states, mask, visual_embeds) + + np.testing.assert_allclose( + np.array(result), + np.full((batch, seq_len, hidden_dim), 2.0), + ) + + def test_no_positions_masked(self): + """When no positions are masked, hidden states are unchanged.""" + batch, seq_len, hidden_dim, num_visual = 2, 6, 8, 1 + hidden_states = jnp.ones((batch, seq_len, hidden_dim)) + mask = jnp.zeros((batch, seq_len), dtype=bool) + visual_embeds = jnp.ones((batch, num_visual, hidden_dim)) * 99.0 + + result = deepstack_process(hidden_states, mask, visual_embeds) + + np.testing.assert_allclose(np.array(result), np.array(hidden_states)) + + +# --------------------------------------------------------------------------- +# 2. NNXDecoderLayer +# --------------------------------------------------------------------------- + + +class TestNNXDecoderLayer(unittest.TestCase): + """Tests for the NNXDecoderLayer NNX module.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config() + self.mesh = _make_mesh(self.cfg) + self.rng = jax.random.PRNGKey(0) + + def _make_layer(self, model_mode=MODEL_MODE_TRAIN): + return NNXDecoderLayer( + config=self.cfg, + mesh=self.mesh, + model_mode=model_mode, + rngs=nnx.Rngs(params=0, dropout=1), + ) + + def _make_inputs(self): + cfg = self.cfg + batch = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + emb_dim = cfg.emb_dim + inputs = jax.random.normal(self.rng, (batch, seq_len, emb_dim)).astype(cfg.dtype) + segment_ids = jnp.full((batch, seq_len), DECODING_ACTIVE_SEQUENCE_INDICATOR) + positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len)) + return inputs, segment_ids, positions + + # --- instantiation --------------------------------------------------------- + + def test_has_pre_self_attention_norm(self): + layer = self._make_layer() + self.assertIsInstance(layer.pre_self_attention_norm, RMSNorm) + + def test_has_self_attention(self): + + layer = self._make_layer() + self.assertIsInstance(layer.self_attention, Attention) + + def test_has_mlp(self): + + layer = self._make_layer() + self.assertIsInstance(layer.mlp, linears.MlpBlock) + + # --- forward pass ---------------------------------------------------------- + + def test_forward_output_shape_train(self): + """Forward pass output shape matches input shape in train mode.""" + layer = self._make_layer(MODEL_MODE_TRAIN) + inputs, segment_ids, positions = self._make_inputs() + out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + self.assertEqual(out.shape, inputs.shape) + + def test_forward_output_dtype(self): + """Output dtype matches config dtype.""" + layer = self._make_layer() + inputs, segment_ids, positions = self._make_inputs() + out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + self.assertEqual(out.dtype, self.cfg.dtype) + + def test_forward_kv_cache_is_none_when_scan_layers_false(self): + """kv_cache return value is not None when scan_layers=False (non-scan returns cache).""" + # With scan_layers=False the layer returns (output, kv_cache). + # kv_cache may be None in train mode (no cache is populated); we just + # verify the call doesn't raise and returns a 2-tuple. + layer = self._make_layer() + inputs, segment_ids, positions = self._make_inputs() + result = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + + def test_forward_deterministic_and_stochastic_consistent_shape(self): + """Output shape is the same regardless of the deterministic flag.""" + layer = self._make_layer() + inputs, segment_ids, positions = self._make_inputs() + out_det, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + out_stoch, _ = layer(inputs, segment_ids, positions, deterministic=False, model_mode=MODEL_MODE_TRAIN) + self.assertEqual(out_det.shape, out_stoch.shape) + + +# --------------------------------------------------------------------------- +# 3. NNXDecoder.get_decoder_layers +# --------------------------------------------------------------------------- + + +class TestNNXDecoderGetDecoderLayers(unittest.TestCase): + """Tests for NNXDecoder.get_decoder_layers.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config() + self.mesh = _make_mesh(self.cfg) + + def _make_decoder(self, **cfg_overrides): + cfg = _make_config(**cfg_overrides) if cfg_overrides else self.cfg + mesh = _make_mesh(cfg) if cfg_overrides else self.mesh + return NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + + def test_default_decoder_block_returns_nnx_decoder_layer(self): + decoder = self._make_decoder(decoder_block=DecoderBlockType.DEFAULT) + layers = decoder.get_decoder_layers() + self.assertEqual(layers, [NNXDecoderLayer]) + + def test_get_decoder_layers_returns_list(self): + decoder = self._make_decoder() + result = decoder.get_decoder_layers() + self.assertIsInstance(result, list) + self.assertGreater(len(result), 0) + + def test_llama2_decoder_block(self): + + decoder = self._make_decoder(model_name="llama2-7b") + layers = decoder.get_decoder_layers() + self.assertEqual(layers, [LlamaDecoderLayer]) + + def test_get_decoder_layers_idempotent(self): + """Calling get_decoder_layers twice returns the same result.""" + decoder = self._make_decoder() + self.assertEqual(decoder.get_decoder_layers(), decoder.get_decoder_layers()) + + +# --------------------------------------------------------------------------- +# 4. NNXDecoder.get_norm_layer +# --------------------------------------------------------------------------- + + +class TestNNXDecoderGetNormLayer(unittest.TestCase): + """Tests for NNXDecoder.get_norm_layer.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config() + self.mesh = _make_mesh(self.cfg) + self.decoder = NNXDecoder( + config=self.cfg, + mesh=self.mesh, + rngs=nnx.Rngs(params=0, dropout=1), + ) + + def test_default_returns_rms_norm(self): + """DEFAULT decoder block should use RMSNorm.""" + # get_norm_layer returns a functools.partial wrapping RMSNorm. + # The decoder_norm attribute is already instantiated via that partial. + self.assertIsInstance(self.decoder.decoder_norm, RMSNorm) + + def test_gpt3_returns_gpt3_layer_norm(self): + + cfg = _make_config(model_name="gpt3-52k") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsInstance(decoder.decoder_norm, Gpt3LayerNorm) + + +# --------------------------------------------------------------------------- +# 5. NNXDecoder.get_remat_policy / minimal_policy +# --------------------------------------------------------------------------- + + +class TestNNXDecoderRematPolicy(unittest.TestCase): + """Tests for NNXDecoder.get_remat_policy and minimal_policy.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config(remat_policy="full") + self.mesh = _make_mesh(self.cfg) + self.decoder = NNXDecoder( + config=self.cfg, + mesh=self.mesh, + rngs=nnx.Rngs(params=0, dropout=1), + ) + + def test_remat_policy_none_returns_none(self): + self.assertIsNone(self.decoder.get_remat_policy()) + + def test_remat_policy_full_returns_none(self): + cfg = _make_config(remat_policy="full") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsNone(decoder.get_remat_policy()) + + def test_remat_policy_minimal_returns_non_none(self): + cfg = _make_config(remat_policy="minimal") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsNotNone(decoder.get_remat_policy()) + + def test_remat_policy_minimal_with_context_returns_non_none(self): + cfg = _make_config(remat_policy="minimal_with_context") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsNotNone(decoder.get_remat_policy()) + + def test_remat_policy_save_qkv_proj_returns_non_none(self): + cfg = _make_config(remat_policy="save_qkv_proj") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsNotNone(decoder.get_remat_policy()) + + def test_remat_policy_save_out_proj_returns_non_none(self): + cfg = _make_config(remat_policy="save_out_proj") + mesh = _make_mesh(cfg) + decoder = NNXDecoder(config=cfg, mesh=mesh, rngs=nnx.Rngs(params=0, dropout=1)) + self.assertIsNotNone(decoder.get_remat_policy()) + + # --- minimal_policy ------------------------------------------------------- + + def test_minimal_policy_no_flags(self): + policy = self.decoder.minimal_policy() + self.assertIsNotNone(policy) + + def test_minimal_policy_with_context(self): + policy = self.decoder.minimal_policy(with_context=True) + self.assertIsNotNone(policy) + + def test_minimal_policy_with_quantization(self): + policy = self.decoder.minimal_policy(with_quantization=True) + self.assertIsNotNone(policy) + + def test_minimal_policy_with_context_and_quantization(self): + policy = self.decoder.minimal_policy(with_context=True, with_quantization=True) + self.assertIsNotNone(policy) + + def test_minimal_policy_returns_distinct_objects_for_different_flags(self): + """Different flag combinations should produce different policy objects.""" + p1 = self.decoder.minimal_policy(with_context=False) + p2 = self.decoder.minimal_policy(with_context=True) + # They're different checkpoint policies; at minimum they're both non-None + # and Python objects that are not the same instance. + self.assertIsNotNone(p1) + self.assertIsNotNone(p2) + + +# --------------------------------------------------------------------------- +# 6. NNXDecoder full forward pass +# --------------------------------------------------------------------------- + + +class TestNNXDecoderForwardPass(unittest.TestCase): + """Integration-style test for NNXDecoder.__call__ in train mode.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config() + self.mesh = _make_mesh(self.cfg) + self.rng = jax.random.PRNGKey(0) + self.rngs = nnx.Rngs(params=0, dropout=1) + + self.decoder = NNXDecoder( + config=self.cfg, + mesh=self.mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=self.rngs, + ) + self.shared_embedding = Embed( + num_embeddings=self.cfg.vocab_size, + num_features=self.cfg.emb_dim, + dtype=self.cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=self.cfg, + mesh=self.mesh, + rngs=self.rngs, + ) + + def _make_token_inputs(self): + cfg = self.cfg + batch = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + ids = jax.random.randint(self.rng, (batch, seq_len), 0, cfg.vocab_size) + segment_ids = jnp.full((batch, seq_len), DECODING_ACTIVE_SEQUENCE_INDICATOR) + positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len)) + return ids, segment_ids, positions + + def test_forward_pass_returns_three_tuple(self): + """__call__ must return (logits, hidden_state, kv_caches).""" + ids, segment_ids, positions = self._make_token_inputs() + result = self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 3) + + def test_logits_shape(self): + """Logits shape: [batch, seq_len, vocab_size].""" + cfg = self.cfg + ids, segment_ids, positions = self._make_token_inputs() + logits, _, _ = self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + expected = (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size) + self.assertEqual(logits.shape, expected) + + def test_hidden_state_shape(self): + """hidden_state shape: [batch, seq_len, emb_dim].""" + cfg = self.cfg + ids, segment_ids, positions = self._make_token_inputs() + _, hidden_state, _ = self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + expected = (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.emb_dim) + self.assertEqual(hidden_state.shape, expected) + + def test_logits_are_finite(self): + """Logits must not contain NaN or Inf in a simple forward pass.""" + ids, segment_ids, positions = self._make_token_inputs() + logits, _, _ = self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertTrue(jnp.all(jnp.isfinite(logits))) + + def test_different_random_seeds_produce_different_logits(self): + """Two randomly-initialised decoders should not produce identical logits.""" + cfg = self.cfg + mesh = self.mesh + rngs2 = nnx.Rngs(params=99, dropout=1) + decoder2 = NNXDecoder(config=cfg, mesh=mesh, model_mode=MODEL_MODE_TRAIN, rngs=rngs2) + shared_emb2 = Embed( + num_embeddings=cfg.vocab_size, + num_features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=cfg, + mesh=mesh, + rngs=rngs2, + ) + ids, segment_ids, positions = self._make_token_inputs() + common_kwargs = { + "decoder_segment_ids": segment_ids, + "deterministic": True, + "model_mode": MODEL_MODE_TRAIN, + } + logits1, _, _ = self.decoder(self.shared_embedding, ids, positions, **common_kwargs) + logits2, _, _ = decoder2(shared_emb2, ids, positions, **common_kwargs) + self.assertFalse(jnp.allclose(logits1, logits2)) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index cb291e13bd..10474239c6 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -636,6 +636,8 @@ def test_moe_deepseek_pipeline_subset(self): "pipeline_parallel_layers=56", "ici_expert_parallelism=16", "dcn_pipeline_parallelism=8", + "first_num_dense_layers=8", + "base_num_decoder_layers=72", ) ) @@ -653,7 +655,7 @@ def test_pipeline_subset(self): "per_device_batch_size=1", "max_target_length=1024", "pipeline_parallel_layers=56", - "base_num_decoder_layers=61", # Remainder of 5 will fail when sharded incorrectly. + "base_num_decoder_layers=64", # Must be divisible by dcn_pipeline_parallelism=8 in NNX scan path. "ici_expert_parallelism=16", "dcn_pipeline_parallelism=8", ) From 0b98f362fe61215d66f9010be0b0cff992c6a117 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Wed, 25 Mar 2026 04:24:06 +0000 Subject: [PATCH 3/5] Update --- src/maxtext/layers/nnx_decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 647d6611a6..6fae5b55c1 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -71,7 +71,7 @@ class NNXDecoderLayer(nnx.Module): """ - Transformer decoder layer converted to NNX. + Transformer decoder layer converted to NNX """ def __init__( From 5ce876189128030504244615df455fd63d66716c Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Wed, 25 Mar 2026 12:51:22 +0800 Subject: [PATCH 4/5] fix --- src/maxtext/layers/nnx_decoders.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 6fae5b55c1..7cf1ba52b3 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -71,7 +71,7 @@ class NNXDecoderLayer(nnx.Module): """ - Transformer decoder layer converted to NNX + Transformer decoder layer converted to NNX. """ def __init__( @@ -529,7 +529,13 @@ def layer_fn(carry, scanned_vars): params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) scanned_state = nnx.State.merge(params, scanned_other) - return final_carry, nnx.merge(graphdef, scanned_state) + # Update the existing module in-place rather than creating a new one. + # Creating a new module via nnx.merge and reassigning (self.layers = new_module) + # would replace a child node in the NNX graph, which is detected as a graph + # structure mutation when the parent module is inside a JAX transformation + # (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity. + nnx.update(layers, scanned_state) + return final_carry, layers def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" From 2243e469c0195a0f8622954d338bb02af75b6880 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Wed, 25 Mar 2026 04:52:30 +0000 Subject: [PATCH 5/5] Update --- src/maxtext/layers/nnx_decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 7cf1ba52b3..75baf581e2 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -71,7 +71,7 @@ class NNXDecoderLayer(nnx.Module): """ - Transformer decoder layer converted to NNX. + Transformer decoder layer converted to NNX """ def __init__(