Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,8 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
enable_nnx: True
pure_nnx_decoder: True

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,14 +534,14 @@ def __init__(
elif self.is_qwen3_next:
self.query_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
)
self.key_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/layers/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
out_sharding = metadata["sharding"]

if out_sharding is not None:
if nnx.PARTITION_NAME in metadata:
partition_name = metadata[nnx.PARTITION_NAME]
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0

sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
if partition_name not in sharding_list:
sharding_list.insert(scan_axis, partition_name)

out_sharding = tuple(sharding_list)

return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
variable.value,
out_sharding, # type: ignore[arg-type]
Expand Down
247 changes: 172 additions & 75 deletions src/maxtext/layers/nnx_decoders.py

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
return y_flat.reshape(input_shape)


def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
def Qwen3NextRMSNorm(
num_features: int,
epsilon: float,
dtype: DType,
weight_dtype: DType,
shard_mode: ShardMode = ShardMode.AUTO,
kernel_axes: tuple[None | str, ...] = (),
parameter_memory_host_offload: bool = False,
*,
rngs: nnx.Rngs,
):
"""
Used for input and post attention layernorms
in Qwen3NextDecoderLayer.
Expand All @@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
return nnx.data(
RMSNorm(
num_features=num_features,
epsilon=eps,
epsilon=epsilon,
dtype=dtype,
weight_dtype=weight_dtype,
shard_mode=shard_mode,
kernel_axes=kernel_axes,
scale_init=linen_initializers.zeros,
parameter_memory_host_offload=parameter_memory_host_offload,
scale_offset=1.0,
rngs=rngs,
)
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import calibration

from maxtext.layers import nnx_wrappers
import qwix
from qwix._src.core import dot_general_qt

Expand Down Expand Up @@ -285,7 +286,7 @@ class Fp8Quantization(Quantization):

def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.Fp8DirectDotGeneralOp
return nnx_wrappers.ToNNX(nn.Fp8DirectDotGeneralOp)

def einsum(self, dtype: DType = jnp.float32):
return _Fp8EinsumWrapper(dtype=dtype)
Expand Down
17 changes: 2 additions & 15 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from maxtext.layers.decoders import Decoder
from maxtext.layers.embeddings import Embed, embed_as_linen
from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen
from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.multimodal import processor as mm_processor
from maxtext.utils import max_utils
Expand Down Expand Up @@ -376,25 +376,12 @@ def __init__(
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
# By convention, this is the last layer in the list.
mtp_layer = layer_types[-1]
mtp_block_linen = multi_token_prediction_block_as_linen(
self.mtp_block = MultiTokenPredictionBlock(
config=self.config,
mesh=self.mesh,
transformer_layer_module=mtp_layer,
decoder=self.decoder,
rngs=rngs,
name="mtp_block",
)
self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)

self.mtp_block.lazy_init(
shared_embedding=self.token_embedder,
main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype),
input_ids=jnp.ones((1, 1), dtype=jnp.int32),
target_ids=jnp.ones((1, 1), dtype=jnp.int32),
target_mask=jnp.ones((1, 1), dtype=jnp.int32),
position_ids=jnp.ones((1, 1), dtype=jnp.int32),
decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32),
deterministic=True,
)

def no_op(self, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def __init__(
# First LayerNorm, applied before the attention block.
self.input_layernorm = Qwen3NextRMSNorm(
num_features=cfg.emb_dim,
eps=cfg.normalization_layer_epsilon,
epsilon=cfg.normalization_layer_epsilon,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
rngs=rngs,
Expand All @@ -987,7 +987,7 @@ def __init__(
# Second LayerNorm, applied before the MoE block.
self.post_attention_layernorm = Qwen3NextRMSNorm(
num_features=cfg.emb_dim,
eps=cfg.normalization_layer_epsilon,
epsilon=cfg.normalization_layer_epsilon,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
rngs=rngs,
Expand Down
File renamed without changes.
Loading
Loading