
Commit 2c16599

feat: initial implementation
1 parent 391731f commit 2c16599

2 files changed

Lines changed: 795 additions & 70 deletions

File tree

src/MaxText/layers/nnx_decoders.py

Lines changed: 182 additions & 70 deletions
@@ -16,37 +16,28 @@
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module
 
-from typing import Any
 import functools
 import inspect
+from typing import Any
 
 import jax
 import jax.numpy as jnp
-from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh
-
 from flax import linen as nn
 from flax import nnx
-from flax.nnx import wrappers as nnx_wrappers
+from jax.ad_checkpoint import checkpoint_name
+from jax.sharding import Mesh
 
-from MaxText.configs.types import PositionalEmbedding
-from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
-from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
-from MaxText import max_logging
-from MaxText.sharding import create_sharding
+from MaxText import max_logging, maxtext_utils, multimodal_utils, sharding
+from MaxText.common_types import (
+    EP_AS_CONTEXT,
+    MODEL_MODE_AUTOREGRESSIVE,
+    MODEL_MODE_PREFILL,
+    MODEL_MODE_TRAIN,
+    Config,
+    DecoderBlockType,
+    ShardMode,
+)
 from MaxText.inference import page_manager
-from MaxText.layers import linears
-from MaxText.layers import initializers
-from MaxText.layers import quantizations
-from MaxText import maxtext_utils
-from MaxText import multimodal_utils
-from MaxText import sharding
-from MaxText.layers.attentions import Attention
-from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.embeddings import Embed, attend_on_embedding
-from MaxText.layers.quantizations import AqtQuantization as Quant
-
-# Import specific layer definitions (assuming these files exist)
 from MaxText.layers import (
     deepseek,
     deepseek_batchsplit,
@@ -55,86 +46,102 @@
     gemma3,
     gpt3,
     gpt_oss,
+    initializers,
+    linears,
     llama2,
     llama4,
     mistral,
     mixtral,
+    nnx_wrappers,
+    quantizations,
     qwen3,
     simple_layer,
 )
+from MaxText.layers import nnx_pipeline as pipeline
+
+# Assumes these modules are adapted for NNX
+from MaxText.layers.attentions import Attention
+from MaxText.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding
+from MaxText.layers.normalizations import RMSNorm
+from MaxText.layers.quantizations import AqtQuantization as Quant
+from MaxText.sharding import create_sharding
 
 
 class NNXDecoderLayer(nnx.Module):
   """
-  Transformer decoder layer converted to NNX.
+  Transformer decoder layer that attends to the encoder.
+  This is the core, reusable building block for both the main model's
+  decoder stack and the auxiliary MTP layers.
   """
 
   def __init__(
       self,
       config: Config,
       mesh: Mesh,
-      model_mode: str,
-      quant: None | Quant = None,
-      name: str = "decoder_layer",
+      quant: Quant | None = None,
+      model_mode: str = MODEL_MODE_TRAIN,
       *,
       rngs: nnx.Rngs,
   ):
     self.config = config
     self.mesh = mesh
-    self.model_mode = model_mode
     self.quant = quant
+    self.model_mode = model_mode
 
-    cfg = self.config
-
+    # Initialize Pre-Attention Norm
     self.pre_self_attention_norm = RMSNorm(
-        num_features=cfg.emb_dim,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        epsilon=cfg.normalization_layer_epsilon,
+        num_features=self.config.emb_dim,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        epsilon=self.config.normalization_layer_epsilon,
         kernel_axes=("norm",),
         rngs=rngs,
     )
 
+    # Initialize Attention
     self.self_attention = Attention(
         config=self.config,
-        num_query_heads=cfg.num_query_heads,
-        num_kv_heads=cfg.num_kv_heads,
-        head_dim=cfg.head_dim,
-        max_target_length=cfg.max_target_length,
-        max_prefill_predict_length=cfg.max_prefill_predict_length,
-        attention_kernel=cfg.attention,
-        inputs_q_shape=(1, 1, cfg.emb_dim),
-        inputs_kv_shape=(1, 1, cfg.emb_dim),
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        inputs_q_shape=(1, 1, self.config.emb_dim),
+        inputs_kv_shape=(1, 1, self.config.emb_dim),
         mesh=mesh,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        dropout_rate=cfg.dropout_rate,
-        float32_qk_product=cfg.float32_qk_product,
-        float32_logits=cfg.float32_logits,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        float32_qk_product=self.config.float32_qk_product,
+        float32_logits=self.config.float32_logits,
         quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(cfg),
-        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-        reshape_q=cfg.reshape_q,
+        kv_quant=quantizations.configure_kv_quant(config),
+        prefill_cache_axis_order=tuple(map(int, self.config.prefill_cache_axis_order.split(","))),
+        ar_cache_axis_order=tuple(map(int, self.config.ar_cache_axis_order.split(","))),
+        compute_axis_order=tuple(map(int, self.config.compute_axis_order.split(","))),
+        reshape_q=self.config.reshape_q,
         model_mode=model_mode,
+        rngs=rngs,
     )
 
-    self.mlp = linears.MLPBlock(
-        in_features=cfg.emb_dim,
-        intermediate_dim=cfg.mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
+    # Initialize MLP
+    self.mlp = linears.MlpBlock(
+        in_features=self.config.emb_dim,
+        intermediate_dim=self.config.mlp_dim,
+        activations=self.config.mlp_activations,
+        intermediate_dropout_rate=self.config.dropout_rate,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         model_mode=model_mode,
-        config=cfg,
+        config=self.config,
         quant=self.quant,
         mesh=self.mesh,
         rngs=rngs,
     )
 
-    self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,))
+    # Initialize Dropout
+    self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,))
 
   def __call__(
       self,
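For readers new to the NNX style used in the constructor above: unlike Linen's setup/compact pattern, every sub-module (norm, attention, MLP, dropout) is built eagerly in __init__ and simply called in __call__, with the residual add applied around the block. A minimal, self-contained sketch of that same pre-norm residual wiring with toy stand-ins; TinyDecoderLayer and its sub-modules are illustrative only and not part of this commit:

from flax import nnx
import jax.numpy as jnp


class TinyDecoderLayer(nnx.Module):
  """Pre-norm residual block with toy stand-ins for the real sub-modules."""

  def __init__(self, emb_dim: int, *, rngs: nnx.Rngs):
    self.norm = nnx.LayerNorm(emb_dim, rngs=rngs)
    self.attn = nnx.Linear(emb_dim, emb_dim, rngs=rngs)  # stand-in for Attention
    self.mlp = nnx.Linear(emb_dim, emb_dim, rngs=rngs)   # stand-in for linears.MlpBlock
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)

  def __call__(self, x, deterministic: bool = True):
    h = self.norm(x)
    h = self.attn(h)
    h = self.mlp(h)
    h = self.dropout(h, deterministic=deterministic)
    return x + h  # residual connection, as in the layer above


layer = TinyDecoderLayer(emb_dim=16, rngs=nnx.Rngs(0))
out = layer(jnp.ones((2, 4, 16)))  # (batch, seq, emb) in, same shape out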
@@ -191,19 +198,72 @@ def __call__(
     layer_output = next_layer_addition_dropped_out + inputs
     layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names)
 
-    if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+    if self.config.record_internal_nn_metrics:
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
 
-    if cfg.scan_layers:
+    if self.config.scan_layers:
       return layer_output, None
-    else:
-      return layer_output, kv_cache
+
+    return layer_output, kv_cache
+
+
+class NNXSequentialBlockDecoderLayers(nnx.Module):
+  """Sequential unscanned series of decoder layers."""
+
+  def __init__(
+      self,
+      decoder_layer: Any,
+      num_decoder_layers: int,
+      config: Config,
+      mesh: Mesh,
+      model_mode: str,
+      rngs: nnx.Rngs,
+      quant: Quant,
+      **kwargs,
+  ):
+    self.config = config
+    self.num_decoder_layers = num_decoder_layers
+
+    layers_list = []
+
+    for _ in range(num_decoder_layers):
+      layers_list.append(decoder_layer(config=config, mesh=mesh, model_mode=model_mode, rngs=rngs, quant=quant, **kwargs))
+    self.layers = nnx.List(layers_list)
+
+  def __call__(
+      self,
+      inputs: jnp.ndarray,
+      decoder_segment_ids,
+      decoder_positions,
+      deterministic: bool,
+      model_mode,
+      slot: int | None = None,
+      page_state: Any | None = None,  # page_manager.PageState
+  ) -> jnp.ndarray:
+
+    # Iterate over the pre-initialized layers
+    for layer in self.layers:
+      inputs = layer(
+          inputs,
+          decoder_segment_ids,
+          decoder_positions,
+          deterministic,
+          model_mode,
+          slot=slot,
+          page_state=page_state,
+      )
+
+      if self.config.scan_layers:
+        inputs = inputs[0]
+    if self.config.scan_layers:
+      return inputs, None  # pytype: disable=bad-return-type
+    return inputs
 
 
 class NNXDecoder(nnx.Module):
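NNXSequentialBlockDecoderLayers simply instantiates num_decoder_layers copies of the given layer class and threads the activations through them in order, unwrapping the (output, None) tuple when scan_layers is set. A minimal sketch of that pattern in plain flax.nnx, with a toy block standing in for the MaxText decoder layer; ToyBlock, SequentialBlocks, and the sizes are illustrative, and nnx.List is used here only because the commit above uses it:

from flax import nnx
import jax.numpy as jnp


class ToyBlock(nnx.Module):
  """Stand-in for a decoder layer: one Linear plus a residual add."""

  def __init__(self, features: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return x + self.linear(x)


class SequentialBlocks(nnx.Module):
  """Unscanned stack: build N sub-modules eagerly, then call them in order."""

  def __init__(self, num_layers: int, features: int, *, rngs: nnx.Rngs):
    # nnx.List registers each sub-module with the parent, mirroring the commit above.
    self.layers = nnx.List([ToyBlock(features, rngs=rngs) for _ in range(num_layers)])

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


blocks = SequentialBlocks(num_layers=4, features=8, rngs=nnx.Rngs(0))
y = blocks(jnp.ones((2, 3, 8)))  # shape preserved: (2, 3, 8)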
@@ -239,7 +299,7 @@ def __init__(
           num_embeddings=config.trainable_position_size,
           num_features=config.emb_dim,
           dtype=config.dtype,
-          embedding_init=nn.initializers.normal(stddev=1.0),
+          embedding_init=nnx.initializers.normal(stddev=1.0),
           config=config,
           mesh=self.mesh,
           rngs=rngs,
@@ -263,9 +323,13 @@ def __init__(
       )
 
     self.scanned_layers = None
+    self.using_pipeline = config.using_pipeline_parallelism
     self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK
     self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3
 
+    if self.using_pipeline:
+      self.pipeline_module = self.get_pipeline_stage_module(decoder_block_classes)
+
     if self.config.scan_layers:
       if self.is_deepseek:
         assert len(decoder_block_classes) == 2
@@ -305,6 +369,45 @@ def __init__(
       for i in range(config.num_decoder_layers):
         self._create_and_register_layer(layer_cls, rngs, "layers", i)
 
+  def get_pipeline_stage_module(self, decoder_blocks):
+    """Creates the Pipeline module with the correct stage configuration."""
+    cfg = self.config
+
+    def get_layer_to_pipeline(blocks, cfg):
+      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+        return blocks[1]
+      else:
+        return blocks[0]
+
+    base_stage_cls = get_layer_to_pipeline(decoder_blocks, cfg)
+
+    if cfg.num_layers_per_pipeline_stage == 1:
+      stage_module = self._create_single_layer(base_stage_cls, self.rngs)
+    elif cfg.scan_layers_per_stage:
+      stage_module = self._create_scanned_layers(
+          base_stage_cls,
+          length=cfg.num_layers_per_pipeline_stage,
+          rngs=self.rngs,
+      )
+    else:
+      stage_module = NNXSequentialBlockDecoderLayers(
+          decoder_layer=base_stage_cls,
+          num_decoder_layers=cfg.num_layers_per_pipeline_stage,
+          config=cfg,
+          mesh=self.mesh,
+          model_mode=self.model_mode,
+          rngs=self.rngs,
+          quant=self.quant,
+      )
+
+    return pipeline.Pipeline(
+        config=cfg,
+        layers=stage_module,
+        mesh=self.mesh,
+        remat_policy=self.get_remat_policy(),
+        rngs=self.rngs,  # Pipeline keeps original RNGs
+    )
+
   def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
     attr_name = f"{base_name}_{i}"
     layer = self._create_single_layer(layer_cls, rngs)
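The stage construction above branches on two config fields: a stage that holds exactly one layer is just a single decoder-layer instance, a stage with scan_layers_per_stage set uses the scanned (vmapped) constructor, and anything else falls back to the unscanned NNXSequentialBlockDecoderLayers stack. A tiny self-contained sketch of that branch order; the StageConfig dataclass here is illustrative, not MaxText's real Config:

from dataclasses import dataclass


@dataclass
class StageConfig:
  num_layers_per_pipeline_stage: int
  scan_layers_per_stage: bool


def describe_stage(cfg: StageConfig) -> str:
  # Mirrors the branch order in get_pipeline_stage_module above.
  if cfg.num_layers_per_pipeline_stage == 1:
    return "single decoder layer per stage"
  if cfg.scan_layers_per_stage:
    return "scanned stack of layers inside each stage"
  return "unscanned NNXSequentialBlockDecoderLayers stack"


print(describe_stage(StageConfig(1, False)))  # single decoder layer per stage
print(describe_stage(StageConfig(4, True)))   # scanned stack of layers inside each stage
print(describe_stage(StageConfig(4, False)))  # unscanned NNXSequentialBlockDecoderLayers stack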
@@ -337,7 +440,8 @@ def create_layer_fn(rng):
     # TODO: Handle this properly.
     try:
       nnx.split_rngs(rngs, splits=length)
-    except:  # pylint: disable=bare-except
+    except Exception as e:  # pylint: disable=bare-except
+      max_logging.log(f"Warning: could not split rngs for scanned layers: {e}")  # pylint: disable=logging-fstring-interpolation
       pass
 
     layers_vmapped = nnx.vmap(
@@ -696,7 +800,15 @@ def __call__(
     if cfg.decoder_block == DecoderBlockType.GEMMA3:
       layer_kwargs["bidirectional_mask"] = bidirectional_mask
 
-    if cfg.scan_layers:
+    if self.using_pipeline:
+      if cfg.pipeline_fsdp_ag_once:
+        logical_partition_spec = None
+      else:
+        logical_partition_spec = None
+      layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
+      y = self.pipeline_module(y, *layer_args, logical_partition_spec=logical_partition_spec)
+
+    elif cfg.scan_layers:
       if self.is_deepseek:
         y, _ = self._apply_layers_sequentially(
             self.dense_stack, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs