Skip to content

Commit b732cb3

Browse files
author
Charles Li
committed
Fix unit test errors
1 parent 77d0c97 commit b732cb3

8 files changed

Lines changed: 208 additions & 444 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ subslice_shape: ""
10791079

10801080
# NNX
10811081
enable_nnx: false
1082-
pure_nnx_decoder: True
1082+
pure_nnx_decoder: false
10831083

10841084
################################## Qwen3-Next Specific Configs ##################################
10851085
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,14 @@ def __init__(
533533
elif self.is_qwen3_next:
534534
self.query_norm = Qwen3NextRMSNorm(
535535
num_features=self.config.head_dim,
536-
eps=self.config.normalization_layer_epsilon,
536+
epsilon=self.config.normalization_layer_epsilon,
537537
dtype=self.config.dtype,
538538
weight_dtype=self.config.weight_dtype,
539539
rngs=self.rngs,
540540
)
541541
self.key_norm = Qwen3NextRMSNorm(
542542
num_features=self.config.head_dim,
543-
eps=self.config.normalization_layer_epsilon,
543+
epsilon=self.config.normalization_layer_epsilon,
544544
dtype=self.config.dtype,
545545
weight_dtype=self.config.weight_dtype,
546546
rngs=self.rngs,

src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs):
668668
)
669669
elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT:
670670
return functools.partial(
671-
normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode
671+
normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs
672672
)
673673
else:
674674
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

src/maxtext/layers/normalizations.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
102102
return y_flat.reshape(input_shape)
103103

104104

105-
def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
105+
def Qwen3NextRMSNorm(
106+
num_features: int,
107+
epsilon: float,
108+
dtype: DType,
109+
weight_dtype: DType,
110+
shard_mode: ShardMode = ShardMode.AUTO,
111+
kernel_axes: tuple[None | str, ...] = (),
112+
parameter_memory_host_offload: bool = False,
113+
*,
114+
rngs: nnx.Rngs,
115+
):
106116
"""
107117
Used for input and post attention layernorms
108118
in Qwen3NextDecoderLayer.
@@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
115125
return nnx.data(
116126
RMSNorm(
117127
num_features=num_features,
118-
epsilon=eps,
128+
epsilon=epsilon,
119129
dtype=dtype,
120130
weight_dtype=weight_dtype,
131+
shard_mode=shard_mode,
132+
kernel_axes=kernel_axes,
121133
scale_init=linen_initializers.zeros,
134+
parameter_memory_host_offload=parameter_memory_host_offload,
122135
scale_offset=1.0,
123136
rngs=rngs,
124137
)

src/maxtext/models/models.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
2929
from maxtext.inference import page_manager
30-
from maxtext.layers.nnx_decoders import NNXDecoder, decoder_as_linen
30+
from maxtext.layers.nnx_decoders import NNXDecoder
3131
from maxtext.layers import initializers
3232
from maxtext.layers import nnx_wrappers
3333
from maxtext.layers.decoders import Decoder
@@ -88,12 +88,7 @@ def setup(self):
8888
)
8989
self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
9090
self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None
91-
if cfg.pure_nnx_decoder:
92-
self.decoder = decoder_as_linen(
93-
config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0)
94-
)
95-
else:
96-
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
91+
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
9792

9893
# If MTP is enabled via config, set up the MTP block.
9994
if self.config.mtp_num_layers > 0:

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def __init__(
962962
# First LayerNorm, applied before the attention block.
963963
self.input_layernorm = Qwen3NextRMSNorm(
964964
num_features=cfg.emb_dim,
965-
eps=cfg.normalization_layer_epsilon,
965+
epsilon=cfg.normalization_layer_epsilon,
966966
dtype=cfg.dtype,
967967
weight_dtype=cfg.weight_dtype,
968968
rngs=rngs,
@@ -987,7 +987,7 @@ def __init__(
987987
# Second LayerNorm, applied before the MoE block.
988988
self.post_attention_layernorm = Qwen3NextRMSNorm(
989989
num_features=cfg.emb_dim,
990-
eps=cfg.normalization_layer_epsilon,
990+
epsilon=cfg.normalization_layer_epsilon,
991991
dtype=cfg.dtype,
992992
weight_dtype=cfg.weight_dtype,
993993
rngs=rngs,

0 commit comments

Comments
 (0)