 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module
 
-from typing import Any
 import functools
 import inspect
+from typing import Any
 
 import jax
 import jax.numpy as jnp
-from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh
-
 from flax import linen as nn
 from flax import nnx
-from flax.nnx import wrappers as nnx_wrappers
+from jax.ad_checkpoint import checkpoint_name
+from jax.sharding import Mesh
 
-from MaxText.configs.types import PositionalEmbedding
-from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
-from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
-from MaxText import max_logging
-from MaxText.sharding import create_sharding
+from MaxText import max_logging, maxtext_utils, multimodal_utils, sharding
+from MaxText.common_types import (
+    EP_AS_CONTEXT,
+    MODEL_MODE_AUTOREGRESSIVE,
+    MODEL_MODE_PREFILL,
+    MODEL_MODE_TRAIN,
+    Config,
+    DecoderBlockType,
+    ShardMode,
+)
 from MaxText.inference import page_manager
-from MaxText.layers import linears
-from MaxText.layers import initializers
-from MaxText.layers import quantizations
-from MaxText import maxtext_utils
-from MaxText import multimodal_utils
-from MaxText import sharding
-from MaxText.layers.attentions import Attention
-from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.embeddings import Embed, attend_on_embedding
-from MaxText.layers.quantizations import AqtQuantization as Quant
-
-# Import specific layer definitions (assuming these files exist)
 from MaxText.layers import (
     deepseek,
     deepseek_batchsplit,
     gemma,
     gemma2,
     gemma3,
     gpt3,
     gpt_oss,
+    initializers,
+    linears,
     llama2,
     llama4,
     mistral,
     mixtral,
+    nnx_wrappers,
+    quantizations,
     qwen3,
     simple_layer,
 )
+from MaxText.layers import nnx_pipeline as pipeline
+
+# Assumes these modules are adapted for NNX
+from MaxText.layers.attentions import Attention
+from MaxText.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding
+from MaxText.layers.normalizations import RMSNorm
+from MaxText.layers.quantizations import AqtQuantization as Quant
+from MaxText.sharding import create_sharding
 
 
 class NNXDecoderLayer(nnx.Module):
   """
-  Transformer decoder layer converted to NNX.
+  Transformer decoder layer that attends to the encoder.
+  This is the core, reusable building block for both the main model's
+  decoder stack and the auxiliary MTP layers.
   """
 
   def __init__(
       self,
       config: Config,
       mesh: Mesh,
-      model_mode: str,
-      quant: None | Quant = None,
-      name: str = "decoder_layer",
+      quant: Quant | None = None,
+      model_mode: str = MODEL_MODE_TRAIN,
       *,
       rngs: nnx.Rngs,
   ):
     self.config = config
     self.mesh = mesh
-    self.model_mode = model_mode
     self.quant = quant
+    self.model_mode = model_mode
 
-    cfg = self.config
-
+    # Initialize Pre-Attention Norm
     self.pre_self_attention_norm = RMSNorm(
-        num_features=cfg.emb_dim,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        epsilon=cfg.normalization_layer_epsilon,
+        num_features=self.config.emb_dim,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        epsilon=self.config.normalization_layer_epsilon,
         kernel_axes=("norm",),
         rngs=rngs,
     )
 
+    # Initialize Attention
     self.self_attention = Attention(
         config=self.config,
-        num_query_heads=cfg.num_query_heads,
-        num_kv_heads=cfg.num_kv_heads,
-        head_dim=cfg.head_dim,
-        max_target_length=cfg.max_target_length,
-        max_prefill_predict_length=cfg.max_prefill_predict_length,
-        attention_kernel=cfg.attention,
-        inputs_q_shape=(1, 1, cfg.emb_dim),
-        inputs_kv_shape=(1, 1, cfg.emb_dim),
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        inputs_q_shape=(1, 1, self.config.emb_dim),
+        inputs_kv_shape=(1, 1, self.config.emb_dim),
         mesh=mesh,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        dropout_rate=cfg.dropout_rate,
-        float32_qk_product=cfg.float32_qk_product,
-        float32_logits=cfg.float32_logits,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        float32_qk_product=self.config.float32_qk_product,
+        float32_logits=self.config.float32_logits,
         quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(cfg),
-        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-        reshape_q=cfg.reshape_q,
+        kv_quant=quantizations.configure_kv_quant(config),
+        prefill_cache_axis_order=tuple(map(int, self.config.prefill_cache_axis_order.split(","))),
+        ar_cache_axis_order=tuple(map(int, self.config.ar_cache_axis_order.split(","))),
+        compute_axis_order=tuple(map(int, self.config.compute_axis_order.split(","))),
+        reshape_q=self.config.reshape_q,
         model_mode=model_mode,
+        rngs=rngs,
     )
 
-    self.mlp = linears.MLPBlock(
-        in_features=cfg.emb_dim,
-        intermediate_dim=cfg.mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
+    # Initialize MLP
+    self.mlp = linears.MlpBlock(
+        in_features=self.config.emb_dim,
+        intermediate_dim=self.config.mlp_dim,
+        activations=self.config.mlp_activations,
+        intermediate_dropout_rate=self.config.dropout_rate,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         model_mode=model_mode,
-        config=cfg,
+        config=self.config,
         quant=self.quant,
         mesh=self.mesh,
         rngs=rngs,
     )
 
-    self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,))
+    # Initialize Dropout
+    self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,))
 
   def __call__(
       self,
@@ -191,19 +198,72 @@ def __call__(
     layer_output = next_layer_addition_dropped_out + inputs
     layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names)
 
-    if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+    if self.config.record_internal_nn_metrics:
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
 
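+    # When the stack is scanned with nnx.scan, each layer must return a
+    # (carry, per-layer output) pair; None fills the unused output slot.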
-    if cfg.scan_layers:
+    if self.config.scan_layers:
       return layer_output, None
-    else:
-      return layer_output, kv_cache
+
+    return layer_output, kv_cache
+
+
+class NNXSequentialBlockDecoderLayers(nnx.Module):
+  """Sequential unscanned series of decoder layers."""
+
+  def __init__(
+      self,
+      decoder_layer: Any,
+      num_decoder_layers: int,
+      config: Config,
+      mesh: Mesh,
+      model_mode: str,
+      rngs: nnx.Rngs,
+      quant: Quant,
+      **kwargs,
+  ):
+    self.config = config
+    self.num_decoder_layers = num_decoder_layers
+
+    layers_list = []
+
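+    # Eagerly instantiate one layer per slot; nnx.List registers each layer
+    # as a submodule so its parameters are tracked by NNX.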
+    for _ in range(num_decoder_layers):
+      layers_list.append(decoder_layer(config=config, mesh=mesh, model_mode=model_mode, rngs=rngs, quant=quant, **kwargs))
+    self.layers = nnx.List(layers_list)
+
+  def __call__(
+      self,
+      inputs: jnp.ndarray,
+      decoder_segment_ids,
+      decoder_positions,
+      deterministic: bool,
+      model_mode,
+      slot: int | None = None,
+      page_state: page_manager.PageState | None = None,
+  ) -> jnp.ndarray:
+
+    # Iterate over the pre-initialized layers
+    for layer in self.layers:
+      inputs = layer(
+          inputs,
+          decoder_segment_ids,
+          decoder_positions,
+          deterministic,
+          model_mode,
+          slot=slot,
+          page_state=page_state,
+      )
+
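+    # With scan_layers enabled each layer returns an (output, None) tuple,
+    # so unwrap the carry before returning.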
+    if self.config.scan_layers:
+      inputs = inputs[0]
+      return inputs, None  # pytype: disable=bad-return-type
+    return inputs
 
 
 class NNXDecoder(nnx.Module):
@@ -239,7 +299,7 @@ def __init__(
           num_embeddings=config.trainable_position_size,
           num_features=config.emb_dim,
           dtype=config.dtype,
-          embedding_init=nn.initializers.normal(stddev=1.0),
+          embedding_init=nnx.initializers.normal(stddev=1.0),
           config=config,
           mesh=self.mesh,
           rngs=rngs,
@@ -263,9 +323,13 @@ def __init__(
       )
 
     self.scanned_layers = None
+    self.using_pipeline = config.using_pipeline_parallelism
     self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK
     self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3
 
+    if self.using_pipeline:
+      self.pipeline_module = self.get_pipeline_stage_module(decoder_block_classes)
+
     if self.config.scan_layers:
       if self.is_deepseek:
         assert len(decoder_block_classes) == 2
@@ -305,6 +369,45 @@ def __init__(
     for i in range(config.num_decoder_layers):
      self._create_and_register_layer(layer_cls, rngs, "layers", i)
 
+  def get_pipeline_stage_module(self, decoder_blocks):
+    """Creates the Pipeline module with the correct stage configuration."""
+    cfg = self.config
+
+    def get_layer_to_pipeline(blocks, cfg):
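+      # DeepSeek configs provide two block classes; the repeated second block
+      # (the MoE stack) is the one that gets pipelined here.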
+      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+        return blocks[1]
+      return blocks[0]
+
+    base_stage_cls = get_layer_to_pipeline(decoder_blocks, cfg)
+
+    if cfg.num_layers_per_pipeline_stage == 1:
+      stage_module = self._create_single_layer(base_stage_cls, self.rngs)
+    elif cfg.scan_layers_per_stage:
+      stage_module = self._create_scanned_layers(
+          base_stage_cls,
+          length=cfg.num_layers_per_pipeline_stage,
+          rngs=self.rngs,
+      )
+    else:
+      stage_module = NNXSequentialBlockDecoderLayers(
+          decoder_layer=base_stage_cls,
+          num_decoder_layers=cfg.num_layers_per_pipeline_stage,
+          config=cfg,
+          mesh=self.mesh,
+          model_mode=self.model_mode,
+          rngs=self.rngs,
+          quant=self.quant,
+      )
+
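+    # Hand the stage module to the NNX pipeline runner; the remat policy is
+    # presumably applied per stage by the Pipeline wrapper.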
+    return pipeline.Pipeline(
+        config=cfg,
+        layers=stage_module,
+        mesh=self.mesh,
+        remat_policy=self.get_remat_policy(),
+        rngs=self.rngs,  # Pipeline keeps original RNGs
+    )
+
   def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
     attr_name = f"{base_name}_{i}"
     layer = self._create_single_layer(layer_cls, rngs)
@@ -337,7 +440,8 @@ def create_layer_fn(rng):
     # TODO: Handle this properly.
     try:
       nnx.split_rngs(rngs, splits=length)
-    except:  # pylint: disable=bare-except
-      pass
+    except Exception as e:  # pylint: disable=broad-except
+      max_logging.log(f"Warning: could not split rngs for scanned layers: {e}")
 
     layers_vmapped = nnx.vmap(
@@ -696,7 +800,15 @@ def __call__(
     if cfg.decoder_block == DecoderBlockType.GEMMA3:
       layer_kwargs["bidirectional_mask"] = bidirectional_mask
 
-    if cfg.scan_layers:
+    if self.using_pipeline:
+      # NOTE: the pipeline_fsdp_ag_once and default paths currently resolve
+      # to the same value here.
+      logical_partition_spec = None
+      layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
+      y = self.pipeline_module(y, *layer_args, logical_partition_spec=logical_partition_spec)
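+      # The pipeline module runs every decoder layer internally, so the
+      # scan/sequential branches below are bypassed when pipelining is on.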
+
+    elif cfg.scan_layers:
       if self.is_deepseek:
         y, _ = self._apply_layers_sequentially(
             self.dense_stack, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs