Skip to content

Commit 29a9b74

Browse files
Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
Commit 29a9b74 (1 parent: c24d321)

6 files changed

Lines changed: 1124 additions & 48 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,8 @@ position_id_per_seconds: 25
10741074
subslice_shape: ""
10751075

10761076
# NNX
1077-
enable_nnx: false
1077+
enable_nnx: True
1078+
pure_nnx_decoder: True
10781079

10791080
################################## Qwen3-Next Specific Configs ##################################
10801081
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ class HardwareAndMesh(BaseModel):
777777
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
778778
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
779779
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
780+
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
780781

781782

782783
class LayoutAndSharding(BaseModel):

src/maxtext/layers/multi_token_prediction.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
25+
from maxtext.layers.nnx_decoders import NNXDecoderLayer
2526
from maxtext.utils.globals import EPS
2627
from maxtext.layers import nnx_wrappers
2728
from maxtext.layers.decoders import DecoderLayer
@@ -70,7 +71,7 @@ def __init__(
7071
config: Config,
7172
mesh: Mesh,
7273
layer_number: int,
73-
transformer_layer_module: Type[DecoderLayer],
74+
transformer_layer_module: Type[NNXDecoderLayer],
7475
*,
7576
rngs: nnx.Rngs,
7677
):
@@ -108,22 +109,12 @@ def __init__(
108109
rngs=rngs,
109110
)
110111
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
111-
mtp_transformer_layer = transformer_layer_module(
112+
self.transformer_layer = transformer_layer_module(
112113
config=cfg,
113114
mesh=mesh,
114115
model_mode=MODEL_MODE_TRAIN,
115116
name=f"mtp_{k}_transformer_layer",
116-
)
117-
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
118-
119-
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
120-
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
121-
self.transformer_layer.lazy_init(
122-
inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
123-
decoder_segment_ids=None,
124-
decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
125-
deterministic=True,
126-
model_mode=MODEL_MODE_TRAIN,
117+
rngs=rngs,
127118
)
128119

129120
@property

0 commit comments

Comments (0)