
Commit f0ddf63

Migrate Decoder to NNX
1 parent: de51021

11 files changed: 746 additions & 100 deletions


src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
@@ -1111,8 +1111,8 @@ position_id_per_seconds: 25
 subslice_shape: ""

 # NNX
-enable_nnx: false
-pure_nnx_decoder: false
+enable_nnx: True
+pure_nnx_decoder: True

 ################################## Qwen3-Next Specific Configs ##################################
 # Kernel size for the 1D convolution in the Gated Delta Net
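With both flags flipped to True, the NNX decoder path becomes the default. As a minimal sketch, here is how config flags like these could gate decoder construction; `build_decoder`, `NNXDecoder`, and `LinenDecoder` are hypothetical stand-ins, not names from this commit:

  # Hypothetical sketch only: illustrates config-gated dispatch between a
  # legacy Linen decoder and the new pure-NNX decoder.
  def build_decoder(config, mesh, rngs):
    if config.enable_nnx and config.pure_nnx_decoder:
      # Pure NNX path: the decoder is a regular nnx.Module whose
      # parameters are created eagerly at construction time.
      return NNXDecoder(config=config, mesh=mesh, rngs=rngs)
    # Fallback Linen path.
    return LinenDecoder(config=config, mesh=mesh)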

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
@@ -534,14 +534,14 @@ def __init__(
     elif self.is_qwen3_next:
       self.query_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,
       )
       self.key_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,
src/maxtext/layers/initializers.py

Lines changed: 10 additions & 0 deletions
@@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
   out_sharding = metadata["sharding"]

   if out_sharding is not None:
+    if nnx.PARTITION_NAME in metadata:
+      partition_name = metadata[nnx.PARTITION_NAME]
+      scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
+
+      sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
+      if partition_name not in sharding_list:
+        sharding_list.insert(scan_axis, partition_name)
+
+      out_sharding = tuple(sharding_list)
+
   return nn.LogicallyPartitioned(  # type: ignore[wrong-keyword-args]
       variable.value,
       out_sharding,  # type: ignore[arg-type]
src/maxtext/layers/nnx_decoders.py

Lines changed: 172 additions & 75 deletions
Large diffs are not rendered by default.

src/maxtext/layers/normalizations.py

Lines changed: 15 additions & 2 deletions
@@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
     return y_flat.reshape(input_shape)


-def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+def Qwen3NextRMSNorm(
+    num_features: int,
+    epsilon: float,
+    dtype: DType,
+    weight_dtype: DType,
+    shard_mode: ShardMode = ShardMode.AUTO,
+    kernel_axes: tuple[None | str, ...] = (),
+    parameter_memory_host_offload: bool = False,
+    *,
+    rngs: nnx.Rngs,
+):
   """
   Used for input and post attention layernorms
   in Qwen3NextDecoderLayer.
@@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
   return nnx.data(
       RMSNorm(
           num_features=num_features,
-          epsilon=eps,
+          epsilon=epsilon,
           dtype=dtype,
           weight_dtype=weight_dtype,
+          shard_mode=shard_mode,
+          kernel_axes=kernel_axes,
           scale_init=linen_initializers.zeros,
+          parameter_memory_host_offload=parameter_memory_host_offload,
           scale_offset=1.0,
           rngs=rngs,
       )
src/maxtext/layers/quantizations.py

Lines changed: 2 additions & 1 deletion
@@ -26,6 +26,7 @@
 from aqt.jax.v2 import tiled_dot_general
 from aqt.jax.v2 import calibration

+from maxtext.layers import nnx_wrappers
 import qwix
 from qwix._src.core import dot_general_qt

@@ -285,7 +286,7 @@ class Fp8Quantization(Quantization):

   def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
     """Returns dot_general configured with aqt params."""
-    return nn.Fp8DirectDotGeneralOp
+    return nnx_wrappers.ToNNX(nn.Fp8DirectDotGeneralOp)

   def einsum(self, dtype: DType = jnp.float32):
     return _Fp8EinsumWrapper(dtype=dtype)
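Returning the op pre-wrapped in nnx_wrappers.ToNNX lets NNX callers consume the Linen Fp8DirectDotGeneralOp without touching Linen APIs themselves. A sketch of the same pattern using Flax's upstream bridge (flax.nnx.bridge.ToNNX), on the assumption that MaxText's wrapper plays the equivalent role:

  import jax.numpy as jnp
  from flax import linen as nn
  from flax import nnx

  # Wrap a Linen module so it behaves like an NNX module.
  linen_dense = nn.Dense(features=4)
  model = nnx.bridge.ToNNX(linen_dense, rngs=nnx.Rngs(0))
  nnx.bridge.lazy_init(model, jnp.ones((1, 8)))  # materialize the Linen params
  y = model(jnp.ones((1, 8)))                    # call like any NNX module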

src/maxtext/models/models.py

Lines changed: 2 additions & 15 deletions
@@ -33,7 +33,7 @@
 from maxtext.layers.decoders import Decoder
 from maxtext.layers.embeddings import Embed, embed_as_linen
 from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
-from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen
+from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
 from maxtext.layers.quantizations import AqtQuantization as Quant
 from maxtext.multimodal import processor as mm_processor
 from maxtext.utils import max_utils
@@ -376,25 +376,12 @@ def __init__(
       # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
       # By convention, this is the last layer in the list.
       mtp_layer = layer_types[-1]
-      mtp_block_linen = multi_token_prediction_block_as_linen(
+      self.mtp_block = MultiTokenPredictionBlock(
           config=self.config,
           mesh=self.mesh,
           transformer_layer_module=mtp_layer,
           decoder=self.decoder,
           rngs=rngs,
-          name="mtp_block",
-      )
-      self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)
-
-      self.mtp_block.lazy_init(
-          shared_embedding=self.token_embedder,
-          main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype),
-          input_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          target_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          target_mask=jnp.ones((1, 1), dtype=jnp.int32),
-          position_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          deterministic=True,
       )

   def no_op(self, *args, **kwargs):
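Dropping the ToNNX wrapper and the lazy_init call reflects the core Linen-to-NNX shift: NNX modules allocate parameters eagerly in __init__ from explicit shapes, so no dummy-shaped forward pass is needed to materialize them. A toy illustration (not MaxText code):

  import jax.numpy as jnp
  from flax import nnx

  class Block(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
      # Parameters are created here, eagerly, from explicit shapes.
      self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
      return self.linear(x)

  block = Block(16, 32, rngs=nnx.Rngs(0))  # params exist immediately
  y = block(jnp.ones((1, 16)))             # no lazy_init / dummy inputs needed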

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 2 deletions
@@ -962,7 +962,7 @@ def __init__(
     # First LayerNorm, applied before the attention block.
     self.input_layernorm = Qwen3NextRMSNorm(
         num_features=cfg.emb_dim,
-        eps=cfg.normalization_layer_epsilon,
+        epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         rngs=rngs,
@@ -987,7 +987,7 @@
     # Second LayerNorm, applied before the MoE block.
     self.post_attention_layernorm = Qwen3NextRMSNorm(
         num_features=cfg.emb_dim,
-        eps=cfg.normalization_layer_epsilon,
+        epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         rngs=rngs,
