7171
7272class NNXDecoderLayer (nnx .Module ):
7373 """
74- Transformer decoder layer converted to NNX.
74+ Transformer decoder layer converted to NNX
7575 """
7676
7777 def __init__ (
@@ -307,11 +307,10 @@ def __init__(
307307 dense_cls , moe_cls = decoder_block_classes
308308
309309 num_dense = config .first_num_dense_layers
310- self .dense_layers = self ._create_scanned_layers (dense_cls , length = num_dense , rngs = rngs )
311-
310+ self .dense_layers = self ._create_scanned_layers (dense_cls , length = num_dense , metadata_axis_name = "dense_layers" , rngs = rngs )
312311 num_moe = config .num_decoder_layers - config .first_num_dense_layers
313-
314- self . moe_layers = self . _create_scanned_layers ( moe_cls , length = num_moe , rngs = rngs )
312+ self . moe_layers = self . _create_scanned_layers ( moe_cls , length = num_moe , metadata_axis_name = "moe_layers" , rngs = rngs )
313+
315314 elif self .is_gemma3 :
316315 attention_pattern_length = len (gemma3 .GEMMA3_ATTENTION_PATTERN )
317316 scan_length = config .num_decoder_layers // attention_pattern_length
@@ -323,7 +322,9 @@ def __init__(
323322 RemattedGemma3Block = gemma3 .Gemma3ScannableBlock
324323
325324 if scan_length > 0 :
326- self .layers = self ._create_scanned_layers (RemattedGemma3Block , length = scan_length , rngs = rngs , ** layer_kwargs )
325+ self .layers = self ._create_scanned_layers (
326+ RemattedGemma3Block , length = scan_length , metadata_axis_name = "layers" , rngs = rngs , ** layer_kwargs
327+ )
327328 self .layers_remainder = RemattedGemma3Block (
328329 config = self .config , mesh = mesh , quant = self .quant , model_mode = self .model_mode , ** rem_layer_kwargs , rngs = rngs
329330 ) # pytype: disable=wrong-keyword-args
@@ -338,7 +339,9 @@ def __init__(
338339 }
339340
340341 if num_layers > 0 :
341- self .layers = self ._create_scanned_layers (layer_cls , length = num_layers , rngs = rngs , ** layer_kwargs )
342+ self .layers = self ._create_scanned_layers (
343+ layer_cls , length = num_layers , metadata_axis_name = "layers" , rngs = rngs , ** layer_kwargs
344+ )
342345 else :
343346 self .layers = nnx .List ([])
344347
@@ -390,34 +393,80 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
390393 )
391394 return nnx_wrappers .ToNNX (layer_linen , rngs = rngs )
392395
def _create_scanned_layers(
    self,
    decoder_layer_class,
    length: int,
    metadata_axis_name: str,
    rngs: nnx.Rngs,
    **layer_kwargs,
):
  """Creates a stacked set of `length` layers by initializing them inside jax.lax.scan.

  The layers are built one scan step at a time instead of through nnx.vmap. The
  stated intent is to lower peak memory during initialization: the scan body is
  traced once and run once per forked RNG slice, rather than materializing every
  layer's parameters in a single vmapped call.

  Args:
    decoder_layer_class: Layer class to instantiate. It is called with `config`,
      `mesh`, `quant`, `model_mode`, `rngs` and `**layer_kwargs`.
    length: Number of layers to stack; passed to `rngs.fork(split=length)`.
    metadata_axis_name: Name written into each variable's `nnx.PARTITION_NAME`
      metadata and inserted into its `out_sharding` tuple at the scan axis.
    rngs: RNG container that is forked into one RNG state per layer.
    **layer_kwargs: Extra keyword arguments forwarded to the layer constructor.

  Returns:
    A module rebuilt with `nnx.merge` from the reference layer's graphdef and the
    stacked variables. `nnx.Param` variables carry the layer dimension at
    `config.param_scan_axis`; every other variable keeps it at axis 0.

  Raises:
    Whatever `rngs.fork` raises when the RNGs cannot be split. Previously a bare
    `except: pass` hid that error and the method failed one statement later with
    an unrelated UnboundLocalError on `forked_rngs`.
  """
  scan_axis = self.config.param_scan_axis

  def _build_layer(layer_rngs):
    """Instantiates one decoder layer with this decoder's shared settings."""
    return decoder_layer_class(
        config=self.config,
        mesh=self.mesh,
        quant=self.quant,
        model_mode=self.model_mode,
        rngs=layer_rngs,
        **layer_kwargs,
    )

  # Fork rngs to get per-layer RNG states for scanning. Deliberately not wrapped
  # in try/except: nothing below can run without `forked_rngs`, so the original
  # exception is the most useful thing to surface.
  forked_rngs = rngs.fork(split=length)
  rngs_graphdef, rngs_state = nnx.split(forked_rngs)

  # Build a reference layer only to capture the module graph structure (graphdef);
  # its variables are discarded. It uses the first slice of the forked rngs (not a
  # fresh dummy Rngs) so the graphdef has the same RNG state leaves as the layers
  # created inside the scan.
  first_rng_state = jax.tree.map(lambda x: x[0], rngs_state)
  ref_layer = _build_layer(nnx.merge(rngs_graphdef, first_rng_state))
  layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...)
  del ref_layer

  def scan_body(carry, rng_state_slice):
    """Creates one layer from its RNG slice and emits its (params, rest) states."""
    layer = _build_layer(nnx.merge(rngs_graphdef, rng_state_slice))
    _, params, rest = nnx.split(layer, nnx.Param, ...)
    return carry, (params, rest)

  # No carry is needed (None); the per-step outputs are stacked by scan.
  _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state)

  # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis.
  if scan_axis != 0:
    stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params)

  def _add_scan_metadata(state, axis):
    """Tags every VariableState in `state` with the scan axis and partition name.

    This stands in for the metadata that nnx.vmap's `transform_metadata` set in
    the previous implementation.
    """

    def _update_leaf(leaf):
      if not isinstance(leaf, nnx.VariableState):
        return leaf
      # Work on a copy so the dict held by the existing leaf is never mutated,
      # whether or not get_metadata() returns the internal dict.
      metadata = dict(leaf.get_metadata())
      metadata[nnx.PARTITION_NAME] = metadata_axis_name
      metadata["param_scan_axis"] = axis
      # Insert the scan axis name into out_sharding so the sharding spec has one
      # entry per dimension of the stacked tensor.
      if "out_sharding" in metadata and metadata["out_sharding"]:
        sharding = list(metadata["out_sharding"])
        sharding.insert(axis, metadata_axis_name)
        metadata["out_sharding"] = tuple(sharding)
      return leaf.replace(**metadata)

    return jax.tree.map(_update_leaf, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))

  stacked_params = _add_scan_metadata(stacked_params, scan_axis)
  # Non-param state was left stacked on axis 0 by the scan.
  stacked_rest = _add_scan_metadata(stacked_rest, 0)

  return nnx.merge(layer_graphdef, stacked_params, stacked_rest)
421470
422471 def _apply_layer_with_remat (self , layer : nnx .Module , y : jax .Array , policy : Any , prevent_cse : bool , ** kwargs ):
423472 """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""
@@ -439,9 +488,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
439488 """Runs the layer stack using nnx.scan."""
440489 policy = self .get_remat_policy ()
441490 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
442- graphdef , params , state = nnx .split (
443- layers , nnx .Param , ...
444- ) # state: the mutable state we carry (KV cache, RNGs, etc.)
491+ graphdef , params , state = nnx .split (layers , nnx .Param , ...)
445492
446493 scan_axis = self .config .param_scan_axis
447494 if scan_axis != 0 :
@@ -451,6 +498,13 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
451498 sig = inspect .signature (layer_cls .__call__ )
452499 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
453500
501+ def _extract_matching_state (template , full ):
502+ if isinstance (template , nnx .State ):
503+ return nnx .State ({k : _extract_matching_state (v , full [k ]) for k , v in template .items ()})
504+ elif isinstance (template , dict ):
505+ return {k : _extract_matching_state (v , full [k ]) for k , v in template .items ()}
506+ return full
507+
454508 def layer_fn (carry , scanned_vars ):
455509 current_params , current_state = scanned_vars
456510
@@ -460,20 +514,28 @@ def layer_fn(carry, scanned_vars):
460514 layer = nnx .merge (graphdef , current_params , current_state )
461515 layer_out = layer (carry , * args , ** valid_kwargs )
462516 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
463- new_current_state = nnx .state (layer )
464-
517+
518+ new_full_state = nnx .state (layer )
519+ new_current_state = _extract_matching_state (current_state , new_full_state )
520+
521+ # ONLY return non-param state to prevent memory duplication of weights
465522 return new_carry , new_current_state
466523
467524 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
468525
469- final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
526+ final_carry , scanned_other = jax .lax .scan (layer_fn , x_in , (params , state ))
470527
471528 if scan_axis != 0 :
472- scanned_params , scanned_other = scanned_state .split (nnx .Param , ...)
473- scanned_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), scanned_params )
474- scanned_state = nnx .State .merge (scanned_params , scanned_other )
475-
476- return final_carry , nnx .merge (graphdef , scanned_state )
529+ params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), params )
530+
531+ scanned_state = nnx .State .merge (params , scanned_other )
532+ # Update the existing module in-place rather than creating a new one.
533+ # Creating a new module via nnx.merge and reassigning (self.layers = new_module)
534+ # would replace a child node in the NNX graph, which is detected as a graph
535+ # structure mutation when the parent module is inside a JAX transformation
536+ # (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity.
537+ nnx .update (layers , scanned_state )
538+ return final_carry , layers
477539
478540 def get_decoder_layers (self ):
479541 """Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -1159,7 +1221,7 @@ def decoder_as_linen(
11591221 model_mode : str ,
11601222 quant : None | Quant = None ,
11611223):
1162- """Creates a Decoder module. """
1224+ """Creates a Decoder module"""
11631225 module = nnx_wrappers .to_linen (
11641226 NNXDecoder ,
11651227 config = config ,
@@ -1171,4 +1233,4 @@ def decoder_as_linen(
11711233 abstract_init = False ,
11721234 metadata_fn = initializers .variable_to_logically_partitioned ,
11731235 )
1174- return module
1236+ return module
0 commit comments