From 743f0967ea5238f4653a12a09cb28a11cc10330d Mon Sep 17 00:00:00 2001 From: Param Bole Date: Thu, 14 May 2026 17:55:17 +0000 Subject: [PATCH] feat: implement DeepSeek-V4 model integration, decoders, and configuration stack Implement full model architecture, decoder integration layers, and execution configurations for DeepSeek-V4 integration into MaxText: - deepseek_v4.py: Model architecture definition supporting cyclical layer stacking and hyper-connections. - decoders.py & nnx_decoders.py: Integration of DeepSeekV4DecoderLayer, supporting get_attention_type routing and scanned vs unrolled compilation parity. - mhc.py & engram.py: Integration of manifold-constrained hyper-connections (mHC) and engram memory management. - Configuration: Register model configs (deepseek_v4-flash.yml, deepseek_v4-tiny.yml) and hyperparameter definitions in base.yml and types.py. - Parity verification: Comprehensive unit test suite (deepseek_v4_vs_reference_test.py) validating end-to-end decoder block parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5. 
--- src/maxtext/configs/base.yml | 2 + .../configs/models/deepseek_v4-flash.yml | 72 ++ .../configs/models/deepseek_v4-tiny.yml | 72 ++ src/maxtext/configs/types.py | 7 + src/maxtext/layers/attention_compressed.py | 51 +- src/maxtext/layers/decoders.py | 108 +- src/maxtext/layers/engram.py | 9 + src/maxtext/layers/linears.py | 8 +- src/maxtext/layers/mhc.py | 92 +- src/maxtext/layers/moe.py | 45 +- src/maxtext/layers/nnx_decoders.py | 119 +- src/maxtext/layers/normalizations.py | 12 +- src/maxtext/models/deepseek_v4.py | 425 +++++++ tests/unit/deepseek_v4_vs_reference_test.py | 1129 ++++++++++++++--- tests/unit/nnx_decoders_test.py | 37 + 15 files changed, 1922 insertions(+), 266 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek_v4-flash.yml create mode 100644 src/maxtext/configs/models/deepseek_v4-tiny.yml create mode 100644 src/maxtext/models/deepseek_v4.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index a59b07b11f..d6fed399d3 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -268,6 +268,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP, # all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers. use_batch_split_schedule: False # a flag if splitting batch into micro-batches to hide communications that yields performance benefits. batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True. +num_hash_layers: 3 # Number of initial MoE layers to apply static Hash Routing. # For complex architectures like llama4 there are repeated sets of # inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope] @@ -1227,6 +1228,7 @@ force_q_layout: false mhc_expansion_rate: 1 # The number of iterations for the Sinkhorn-Knopp algorithm. 
sinkhorn_iterations: 20 +hc_eps: 1.0e-6 ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. diff --git a/src/maxtext/configs/models/deepseek_v4-flash.yml b/src/maxtext/configs/models/deepseek_v4-flash.yml new file mode 100644 index 0000000000..5430e9d1b8 --- /dev/null +++ b/src/maxtext/configs/models/deepseek_v4-flash.yml @@ -0,0 +1,72 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Default model configs for DeepSeek-V4-Flash (43 Layers) + +base_config: base.yml +model_name: deepseek_v4-flash + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +head_dim: 512 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +base_num_decoder_layers: 43 +first_num_dense_layers: 0 +mlp_activations: ["silu"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 256 +num_experts_per_tok: 6 +shared_experts: 1 +routed_scaling_factor: 1.5 +routed_score_func: "sqrtsoftplus" +routed_bias: True +norm_topk_prob: True +decoder_block: "deepseek_v4" +pure_nnx_decoder: True +enable_nnx: True + +# Manifold-Constrained Hyper-Connection configurations +mhc_expansion_rate: 4 +sinkhorn_iterations: 20 +compress_rope_theta: 160000.0 +index_head_dim: 128 +index_n_heads: 64 +index_topk: 512 +o_groups: 8 +o_lora_rank: 1024 +sliding_window: 128 +num_hash_layers: 3 +mlp_activations_limit: 10.0 +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# Compressed Sparse Attention +q_lora_rank: 1024 +kv_lora_rank: 512 +qk_nope_head_dim: 128 +qk_rope_head_dim: 64 +v_head_dim: 128 +mscale: 1.0 + +# RoPE +rope_type: "default" +rope_max_timescale: 10_000 +max_position_embeddings: 1048576 +original_max_position_embeddings: 65536 +rope_factor: 16 +beta_fast: 32 diff --git a/src/maxtext/configs/models/deepseek_v4-tiny.yml b/src/maxtext/configs/models/deepseek_v4-tiny.yml new file mode 100644 index 0000000000..cc91244240 --- /dev/null +++ b/src/maxtext/configs/models/deepseek_v4-tiny.yml @@ -0,0 +1,72 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny version of DeepSeek-V4 (6 Layers) for local sharding and compilation checks. + +base_config: base.yml +model_name: deepseek_v4-tiny + +base_emb_dim: 128 +base_num_query_heads: 16 +base_num_kv_heads: 1 +head_dim: 32 +base_mlp_dim: 128 +base_moe_mlp_dim: 128 +base_num_decoder_layers: 6 +first_num_dense_layers: 0 +mlp_activations: ["silu"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 8 +num_experts_per_tok: 4 +shared_experts: 1 +routed_scaling_factor: 1.5 +routed_score_func: "sqrtsoftplus" +routed_bias: True +norm_topk_prob: True +decoder_block: "deepseek_v4" +pure_nnx_decoder: True +enable_nnx: True + +# Manifold-Constrained Hyper-Connection configurations +mhc_expansion_rate: 4 +sinkhorn_iterations: 20 +compress_rope_theta: 160000.0 +index_head_dim: 32 +index_n_heads: 16 +index_topk: 64 +o_groups: 2 +o_lora_rank: 64 +sliding_window: 32 +num_hash_layers: 3 +mlp_activations_limit: 10.0 +compress_ratios: [0, 4, 128, 4, 128, 0] + +# Compressed Attention +q_lora_rank: 64 +kv_lora_rank: 32 +qk_nope_head_dim: 32 +qk_rope_head_dim: 16 +v_head_dim: 128 +mscale: 1.0 + +# RoPE +rope_type: "default" +rope_max_timescale: 10_000 +max_position_embeddings: 163840 +original_max_position_embeddings: 4096 +rope_factor: 40 +beta_fast: 32 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e663bcfc5c..b27add7836 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -226,6 +226,8 @@ class ProfilerType(str, Enum): "deepseek3-test", 
"deepseek3-tiny", "deepseek3.2-671b", + "deepseek_v4-tiny", + "deepseek_v4-flash", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -831,6 +833,10 @@ class DeepSeekMoE(BaseModel): 1, description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.", ) + num_hash_layers: int = Field( + 3, + description="Number of initial MoE layers to apply static Hash Routing.", + ) class Qwen3Next(BaseModel): @@ -1381,6 +1387,7 @@ class ManifoldConstrainedHyperConnections(BaseModel): mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.") sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.") + hc_eps: float = Field(1e-6, description="The epsilon fallback value for numerical stability in mHC.") class DilocoParams(BaseModel): diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index eba547d68f..a5f6918618 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -105,7 +105,7 @@ def __init__( # Interleaved rotary embeddings applied to the trailing slice self.rotary_emb = DeepSeekV4RotaryEmbedding( head_dim=head_dim, - partial_rotary_factor=64.0 / 512.0, + partial_rotary_factor=config.qk_rope_head_dim / config.head_dim, rope_theta=rope_theta, ) @@ -328,7 +328,7 @@ def __init__( # Interleaved rotary embedding aligning query/key pos representations self.rotary_emb = DeepSeekV4RotaryEmbedding( head_dim=self.head_dim, - partial_rotary_factor=(config.head_dim * (64.0 / 512.0)) / self.head_dim, + partial_rotary_factor=config.qk_rope_head_dim / self.head_dim, rope_theta=rope_theta, ) @@ -582,7 +582,7 @@ def __init__( # Interleaved rotary embeddings for compressed sequences self.rotary_emb = DeepSeekV4RotaryEmbedding( head_dim=head_dim, - partial_rotary_factor=64.0 / 512.0, + partial_rotary_factor=config.qk_rope_head_dim / 
config.head_dim, rope_theta=rope_theta, ) @@ -764,6 +764,7 @@ def __init__( eps: float = 1e-6, weight_dtype: Any = jnp.float32, dtype: Any = jnp.float32, + attention_type: str = "compressed_sparse_attention", *, rngs: nnx.Rngs, ): @@ -779,12 +780,13 @@ def __init__( eps: Tiny additive variance limit for RMS normalization stability. weight_dtype: The parameter weights numerical data type. dtype: The mathematical execution numerical data type. + attention_type: The type of compressed attention being instantiated. rngs: The Flax NNX random number generator collection. """ super().__init__() self.config = config self.layer_idx = layer_idx - self.layer_type = config.layer_types[layer_idx] + self.attention_type = attention_type self.num_heads = num_heads self.head_dim = head_dim self.sliding_window = config.sliding_window @@ -858,7 +860,7 @@ def __init__( self.sinks = nnx.Param(jax.nn.initializers.zeros(rngs.params(), (num_heads,), weight_dtype)) # Layer specific compressor allocation - if self.layer_type == "heavily_compressed_attention": + if self.attention_type == "heavily_compressed_attention": self.compressor = HCACompressor( hidden_size=hidden_size, head_dim=head_dim, @@ -869,7 +871,7 @@ def __init__( dtype=dtype, rngs=rngs, ) - elif self.layer_type == "compressed_sparse_attention": + elif self.attention_type == "compressed_sparse_attention": self.compressor = CSACompressor( hidden_size=hidden_size, q_lora_rank=q_lora_rank, @@ -884,14 +886,33 @@ def __init__( else: self.compressor = None + # Compute partial rotary factor dynamically from config to prevent dimension mismatches. + # DeepSeek-V4 pairs consecutive channels to apply partial RoPE on qk_rope_head_dim channels, + # requiring dynamic scaling: partial_rotary_factor = qk_rope_head_dim / head_dim. 
+ self.partial_rotary_factor = self.config.qk_rope_head_dim / self.config.head_dim + + self.rope_theta = ( + self.config.rope_max_timescale if self.attention_type == "sliding_attention" else self.config.compress_rope_theta + ) + + # Local rotary embedding block matching standard MaxText (Gemma/Llama2) paradigms. + self.rotary_embedding = DeepSeekV4RotaryEmbedding( + head_dim=self.head_dim, + partial_rotary_factor=self.partial_rotary_factor, + rope_theta=self.rope_theta, + ) + def __call__( self, - hidden_states: jnp.ndarray, - cos: jnp.ndarray, - sin: jnp.ndarray, - position_ids: jnp.ndarray, + hidden_states: jnp.ndarray | None = None, + position_ids: jnp.ndarray | None = None, attention_mask: jnp.ndarray | None = None, + inputs_q: jnp.ndarray | None = None, + inputs_kv: jnp.ndarray | None = None, + **kwargs, ) -> tuple[jnp.ndarray, jnp.ndarray]: + if hidden_states is None: + hidden_states = inputs_q """Executes DeepSeek-V4 compressed multi-head attention. This method projects input states to query representations, applies low-rank @@ -903,8 +924,6 @@ def __call__( Args: hidden_states: The input hidden representation sequence of shape [B, S, D_model]. - cos: Positional RoPE cosine frequencies array of shape [B, S, D_rope]. - sin: Positional RoPE sine frequencies array of shape [B, S, D_rope]. position_ids: Absolute sequence position identifiers of shape [B, S]. attention_mask: Optional attention mask of shape [B, 1, S, S_kv]. @@ -914,10 +933,16 @@ def __call__( - The final multi-head attention weights of shape [B, H, S, S_kv]. """ # hidden_states shape: [B, S, D_model] - # cos, sin shape: [B, S, D_rope] # position_ids shape: [B, S] # attention_mask shape: [B, 1, S_q, S_kv] batch, seq_len, _ = hidden_states.shape + # Unconditionally compute RoPE positional frequency embeddings locally from position IDs. 
+ if position_ids is None: + # [B, S] position sequence index grid broadcast + position_ids = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32)[None], (batch, seq_len)) + # cos/sin shape: [B, S, qk_rope_head_dim / 2] + cos, sin = self.rotary_embedding(hidden_states, position_ids) + h_shape = (batch, seq_len, self.num_heads, self.head_dim) # Project inputs to query representations diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..22ba104255 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -22,6 +22,8 @@ from flax import linen as nn from flax import nnx +from maxtext.layers import nnx_wrappers +from maxtext.models.deepseek_v4 import DeepSeekV4HyperHead from flax.linen.partitioning import ScanIn import jax from jax.ad_checkpoint import checkpoint_name @@ -43,6 +45,7 @@ deepseek, deepseek_batchsplit, deepseek_batchsplit_fp8, + deepseek_v4, gemma, gemma2, gemma3, @@ -460,6 +463,12 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK_V4: + return ( + [deepseek_v4.DeepSeekV4ScannableBlockToLinen] + if self.config.scan_layers + else [deepseek_v4.DeepSeekV4DecoderLayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -983,6 +992,20 @@ def __call__( page_state, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None + y = self._apply_deepseek_v4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=bidirectional_mask_value, + decoder_input_tokens=decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1089,6 +1112,16 @@ def 
__call__( RemattedBlockLayer = RemattedBlockLayers[0] layer_kwargs = {} layer_call_kwargs = {} + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Retrieve layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. + compress_ratio = self.config.compress_ratios[lyr] + bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None + layer_kwargs = {"compress_ratio": compress_ratio, "layer_idx": lyr} + layer_call_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask_value, + } if cfg.decoder_block == DecoderBlockType.GEMMA3: # Gemma3 uses both global and sliding window attention depending on the layer index. bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None @@ -1151,7 +1184,11 @@ def __call__( assert isinstance(y, jax.Array) # After the final transformer layer, `y` holds the raw, un-normalized hidden state. 
- if cfg.mhc_expansion_rate > 1: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Collapse final streams using learnable collapse weights [B, S, k, D] -> [B, S, D] + hc_head = nnx_wrappers.to_linen_class(DeepSeekV4HyperHead, name="hc_head")(config=cfg) + hidden_state = hc_head(y) + elif cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) hidden_state = mhc_reduce(y) else: @@ -1335,6 +1372,75 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek_v4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=None, + decoder_input_tokens=None, + ): + """Applies DeepSeek-V4 scanned decoder blocks under Flax Linen, handling main scan and remainders.""" + cfg = self.config + mesh = self.mesh + + # Define the repeating pattern length (2 for cyclical DeepSeek-V4 layers) + scan_length = cfg.num_decoder_layers // 2 + + policy = self.get_remat_policy() + RemattedDSV4Block = self.set_remat_policy([deepseek_v4.DeepSeekV4ScannableBlockToLinen], policy)[0] + + layer_call_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask, + } + layer_kwargs = {"num_of_layers": 2} + + # Apply the main scan over the full blocks + if scan_length > 0: + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + # inputs: y shape [B, S, k, D] -> [B, S, k, D] + y, _ = self.scan_decoder_layers( + cfg, + RemattedDSV4Block, + scan_length, + "layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=self.model_mode, + **layer_kwargs, + )(y, *broadcast_args, **layer_call_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % 2 + if num_remaining_layers > 0: + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + layer = 
RemattedDSV4Block( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs + ) + y, _ = layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + **layer_call_kwargs, + ) + return y + # TODO(b/490118813): Relocate the following functions to their designated directories # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer() # _apply_scanned_chunk() and _apply_interleaved_scanned_layers(). diff --git a/src/maxtext/layers/engram.py b/src/maxtext/layers/engram.py index 3b2eb4e2b5..f457097d72 100644 --- a/src/maxtext/layers/engram.py +++ b/src/maxtext/layers/engram.py @@ -335,6 +335,15 @@ class StaticWrapper: def __init__(self, val): self.val = val + def __getitem__(self, key): + return self.val[key] + + def __setitem__(self, key, value): + if key is Ellipsis: + self.val = value + else: + self.val = self.val.at[key].set(value) + class MultiHeadEmbedding(nnx.Module): """ diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 69cf7baf9f..519f35a5bc 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -603,7 +603,7 @@ def __init__( # Grouped block-diagonal projection kernel parameters # Kernels are stored as a 3D tensor: [n_groups, in_features_per_group, out_features_per_group] kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group) - self.weight = nnx.Param( + self.kernel = nnx.Param( kernel_init( rngs.params(), kernel_shape, @@ -623,10 +623,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: Projected tensor of shape [..., n_groups, out_features_per_group] """ x = jnp.asarray(x, self.dtype) - weight = jnp.asarray(self.weight[...], self.dtype) + kernel = jnp.asarray(self.kernel[...], self.dtype) # Execute parallel group projection via optimized einsum broadcasting. 
# x: [..., g, i] - # weight: [g, i, o] + # kernel: [g, i, o] # output: [..., g, o] - return jnp.einsum("...gi,gio->...go", x, weight) + return jnp.einsum("...gi,gio->...go", x, kernel) diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index ce700aafcd..c479ebd177 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -19,10 +19,10 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from maxtext.common.common_types import Array, Config +from maxtext.common.common_types import Array, Config, DecoderBlockType from maxtext.common.common_types import HyperConnectionType from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init -from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.normalizations import DeepSeekV4UnweightedRMSNorm, RMSNorm def get_functions(expansion_rate: int): @@ -42,26 +42,26 @@ def reduce(x: Array): return expand, reduce -def sinkhorn(t, iters=20): +def sinkhorn(t, iters=20, eps=1e-12): """Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1).""" - # Use float32 precision for numerical stability during normalization + # Use float32 precision for numerical stability during alternating L1 row/column normalizations. + # val: [B, S, H, H] initial_dtype = t.dtype t = t.astype(jnp.float32) - # Column-wise normalization (axis=-2) - positive and sum up to 1 across columns - # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2) - t = jax.nn.softmax(t, axis=-2) + # Column normalization first (sum along axis -2) matching Xie et al. Equation 8 initialization. + t = t / (jnp.sum(t, axis=-2, keepdims=True) + eps) def body_fun(i, val): - # L1 Normalization: val / sum(val) with clipping of denominator + # L1 Normalization: val / (sum(val) + eps) matching the exact denominator addition. 
# Normalize rows (axis -1) - val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12) + val = val / (jnp.sum(val, axis=-1, keepdims=True) + eps) # Normalize columns (axis -2) - val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12) + val = val / (jnp.sum(val, axis=-2, keepdims=True) + eps) return val - # Use lax.fori_loop for an efficient, JIT-friendly loop - t = jax.lax.fori_loop(0, iters, body_fun, t) + # Use lax.fori_loop for an efficient, JIT-friendly loop over exactly iters - 1 steps. + t = jax.lax.fori_loop(0, iters - 1, body_fun, t) return t.astype(initial_dtype) @@ -95,14 +95,20 @@ def __init__( self.matmul_precision = jax.lax.Precision(self.config.matmul_precision) # Norm layer - self.mhc_norm = RMSNorm( - num_features=self.k * self.dim, - dtype=self.config.dtype, - weight_dtype=self.weight_dtype, - kernel_axes=("norm",), - epsilon=self.config.normalization_layer_epsilon, - rngs=self.rngs, - ) + if getattr(self.config, "decoder_block", None) == DecoderBlockType.DEEPSEEK_V4: + self.mhc_norm = DeepSeekV4UnweightedRMSNorm( + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + ) + else: + self.mhc_norm = RMSNorm( + num_features=self.k * self.dim, + dtype=self.config.dtype, + weight_dtype=self.weight_dtype, + kernel_axes=("norm",), + epsilon=self.config.normalization_layer_epsilon, + rngs=self.rngs, + ) # Scalars self.res_alpha_scale = nnx.Param( @@ -170,28 +176,33 @@ def __init__( def res_mapping(self, x: Array): """Helper function for residual mapping.""" - # In MaxText, we match weight precision to activations before Matmul + # In MaxText, we match weight precision to activations before Matmul. + # x: [B, S, H * D] representing sequence token features. 
+ # res_alpha: [H * D, H * H] + # res_beta: [H, H] res_alpha = jnp.asarray(self.res_alpha[...], self.dtype) res_beta = jnp.asarray(self.res_beta[...], self.dtype) res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) b, s, _ = h_res.shape h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] - output = sinkhorn(intermediate, self.sinkhorn_iterations) + # Apply softmax pre-normalization along the trailing axis matching the exact initialization. + # intermediate: [B, S, H, H] + intermediate = jax.nn.softmax(intermediate, axis=-1) + self.config.hc_eps + output = sinkhorn(intermediate, self.sinkhorn_iterations, eps=self.config.hc_eps) return output - def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): + def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float, eps: float = 0.0): """Helper function for both pre and post mappings.""" # In MaxText, we match weight precision to activations before Matmul alpha = jnp.asarray(alpha, self.dtype) beta = jnp.asarray(beta, self.dtype) alpha_scale = jnp.asarray(alpha_scale, self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k) - h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision) + # Apply projection: (b, s, e*d) @ (e*d, e) -> (b, s, e) + h = jnp.einsum("bsm,me -> bse", x, alpha, precision=self.matmul_precision) intermediate = alpha_scale * h + beta[None, None, :] - output = scale * jax.nn.sigmoid(intermediate) + output = scale * jax.nn.sigmoid(intermediate) + eps return output def __call__( @@ -227,8 +238,12 @@ def __call__( self.pre_alpha[...], self.pre_beta[...], 1.0, + self.config.hc_eps, ) - layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision) + layer_input = 
jnp.einsum( + "bsed,bse -> bsd", x.astype(jnp.float32), pre_mapping.astype(jnp.float32), precision=self.matmul_precision + ) + layer_input = layer_input.astype(self.dtype) # 3. Pre-norm layer_input = norm_fn(layer_input) @@ -246,22 +261,31 @@ def __call__( else: raise ValueError(f"Unsupported type: {mhc_type}") - # 5. Post mapping + # 5. Post mapping (multiplied by 2.0 matching post_scale) post_mapping = self.mapping( norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0, + 0.0, ) post_out = jnp.einsum( - "bsd,bsk -> bskd", - layer_out, - post_mapping, + "bsd,bse -> bsed", + layer_out.astype(jnp.float32), + post_mapping.astype(jnp.float32), precision=self.matmul_precision, ) # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] res_mapping = self.res_mapping(norm_x) - res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision) - return res_out + post_out, metadata + # Transposed residual mixing (bsme @ bsmd -> bsed) matching Xie et al. Equation 8 + # representing the projection index: comb.T @ residual stream values. 
+ # res_mapping: [B, S, H_src, H_dest] + # x: [B, S, H_src, D] + # res_out: [B, S, H_dest, D] + res_out = jnp.einsum( + "bsme,bsmd -> bsed", res_mapping.astype(jnp.float32), x.astype(jnp.float32), precision=self.matmul_precision + ) + output = res_out + post_out + return output.astype(self.dtype), metadata diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 18c55b6571..552541e76b 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -38,6 +38,7 @@ from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical, maybe_shard_with_pspec from maxtext.utils.sharding import logical_to_mesh_axes +from maxtext.layers.engram import StaticWrapper import numpy as np import qwix from qwix.contrib.sparsity import sparsity_module @@ -368,9 +369,14 @@ def __call__(self, hidden_states: jax.Array) -> Tuple[jax.Array, jax.Array, jax. kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) - # Apply custom scoring function (sqrtsoftplus). + # Apply routed scoring function from configuration. # [tokens, num_experts] -> [tokens, num_experts] - scores = _sqrtsoftplus(logits) + score_fn = ( + _sqrtsoftplus + if self.config.routed_score_func == "sqrtsoftplus" + else linears._convert_to_activation_function(self.config.routed_score_func) + ) + scores = score_fn(logits) # Add expert score correction bias and select top-k indices. # [tokens, num_experts] + [num_experts] -> [tokens, num_experts] @@ -437,7 +443,9 @@ def __init__( # Static token-to-expert mapping table. # Shape: [vocab_size, top_k] - self.tid2eid = nnx.Param( + # Using StaticWrapper isolates the non-differentiable static lookup indices from the + # autograd gradient tracking tree, ensuring stable backward pass compiles. 
+ self.tid2eid = StaticWrapper( jnp.zeros((config.vocab_size, self.top_k), dtype=jnp.int32), ) @@ -452,15 +460,21 @@ def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> Tuple[jax. kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) - # Apply custom scoring function (sqrtsoftplus). + # Apply routed scoring function from configuration. # [tokens, num_experts] -> [tokens, num_experts] - scores = _sqrtsoftplus(logits) + score_fn = ( + _sqrtsoftplus + if self.config.routed_score_func == "sqrtsoftplus" + else linears._convert_to_activation_function(self.config.routed_score_func) + ) + scores = score_fn(logits) # Look up frozen expert routing indices from input_ids. # [batch, seq_len] -> [tokens] flat_input_ids = input_ids.reshape(-1) + # Look up from StaticWrapper to retrieve frozen lookup indices. # [vocab_size, top_k] sliced at [tokens] -> [tokens, top_k] - indices = self.tid2eid[...][flat_input_ids] + indices = self.tid2eid.val[flat_input_ids] # Gather corresponding scores for the statically selected expert indices. 
# [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k] @@ -558,8 +572,8 @@ def __init__( else: self._expert_parallelism_name = "expert" - self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr( - config, "num_hash_layers", 3 + self.is_hash = ( + self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < config.num_hash_layers ) if self.is_hash: self.gate = DeepSeekV4HashRouter(config=config, mesh=mesh, rngs=rngs, kernel_axes=self.kernel_axes) @@ -1403,7 +1417,10 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) - if self.config.model_name.startswith("deepseek3"): + if ( + self.config.model_name.startswith("deepseek3") + or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 + ): pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models @@ -1797,7 +1814,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): input_axes = (batch_logical_axis, "activation_norm_length", None) gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) - if self.config.model_name.startswith("deepseek3"): + if self.config.model_name.startswith("deepseek3") or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) else: pre_bias_logits_axes = None @@ -2496,11 +2513,13 @@ def __call__( if self.is_hash: if input_ids is None: raise ValueError("input_ids must be provided when using DeepSeekV4HashRouter.") - gate_logits, pre_bias_logits, _ = self.gate(routing_inputs, input_ids) + gate_logits, gate_weights_val, gate_indices_val = self.gate(routing_inputs, input_ids) else: - gate_logits, 
pre_bias_logits, _ = self.gate(routing_inputs) + gate_logits, gate_weights_val, gate_indices_val = self.gate(routing_inputs) gate_logits = gate_logits.reshape(batch_size, seq_len, -1) - pre_bias_logits = pre_bias_logits.reshape(batch_size, seq_len, -1) + gate_weights = gate_weights_val.reshape(batch_size, seq_len, -1) + gate_indices = gate_indices_val.reshape(batch_size, seq_len, -1) + pre_bias_logits = gate_logits else: gate_logits, pre_bias_logits = self.gate(routing_inputs) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..ec04179489 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -19,7 +19,7 @@ import functools import inspect import warnings -from typing import Any +from typing import Any, Optional import jax import jax.numpy as jnp @@ -48,6 +48,7 @@ deepseek, deepseek_batchsplit, deepseek_batchsplit_fp8, + deepseek_v4, gemma, gemma2, gemma3, @@ -63,6 +64,7 @@ qwen3_5, simple_layer, ) +from maxtext.models.deepseek_v4 import DeepSeekV4HyperHead from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis @@ -299,9 +301,13 @@ def __init__( self.scanned_layers = None self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK + self.is_deepseek_v4 = self.config.decoder_block == DecoderBlockType.DEEPSEEK_V4 self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 self.is_gemma4 = self.config.decoder_block == DecoderBlockType.GEMMA4 + if self.is_deepseek_v4: + self.hc_head = DeepSeekV4HyperHead(config, rngs=rngs) + if self.config.scan_layers: if self.is_deepseek: assert len(decoder_block_classes) == 2 @@ -389,6 +395,23 @@ def __init__( self.layers_remainder = RemattedGemma4Block( config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs ) + elif 
self.is_deepseek_v4: + scan_length = config.num_decoder_layers // 2 + num_remaining_layers = config.num_decoder_layers % 2 + layer_kwargs = {"num_of_layers": 2} + + rem_layer_kwargs = {"num_of_layers": num_remaining_layers, "layer_offset": scan_length * 2} + + RemattedDeepSeekV4Block = deepseek_v4.DeepSeekV4ScannableBlock + + if scan_length > 0: + self.layers = self._create_scanned_layers( + RemattedDeepSeekV4Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + if num_remaining_layers > 0: + self.layers_remainder = RemattedDeepSeekV4Block( + config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) else: layer_cls = decoder_block_classes[0] num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) @@ -435,6 +458,11 @@ def __init__( layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} elif config.decoder_block == DecoderBlockType.OLMO3: layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.DEEPSEEK_V4: + # Retrieve layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. 
+ compress_ratio = self.config.compress_ratios[lyr] + layer_kwargs = {"compress_ratio": compress_ratio, "layer_idx": lyr} self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) @@ -713,6 +741,9 @@ def get_deepseek(): DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], DecoderBlockType.DEEPSEEK: get_deepseek(), + DecoderBlockType.DEEPSEEK_V4: get_scannable( + deepseek_v4.DeepSeekV4DecoderLayer, deepseek_v4.DeepSeekV4ScannableBlock + ), DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), DecoderBlockType.QWEN3_5: get_scannable(qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock), @@ -863,6 +894,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, DecoderBlockType.OLMO3, + DecoderBlockType.DEEPSEEK_V4, ): return functools.partial(RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) elif self.config.decoder_block == DecoderBlockType.GPT3: @@ -1118,7 +1150,7 @@ def __call__( # Extract the bidirectional mask locally for layer configurations bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None - if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): + if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4, DecoderBlockType.DEEPSEEK_V4): layer_kwargs["bidirectional_mask"] = bidirectional_mask if attention_metadata is not None: @@ -1212,6 +1244,19 @@ def __call__( page_state, slot, ) + elif self.is_deepseek_v4: + y = self._apply_deepseek_v4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask=bidirectional_mask, + 
decoder_input_tokens=decoder_input_tokens, + ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) if kv_caches is not None: @@ -1230,13 +1275,16 @@ def __call__( prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) # Hoisted function to preserve XLA cache ID - def pure_layer_fn(graphdef, state_in, y_in, kv_in): + def pure_layer_fn(graphdef, state_in, y_in, kv_in, decoder_input_tokens_in=None): if cfg.parameter_memory_host_offload: state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in) merged_layer = nnx.merge(graphdef, state_in) - out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) + call_kwargs = dict(layer_kwargs) + if decoder_input_tokens_in is not None: + call_kwargs["decoder_input_tokens"] = decoder_input_tokens_in + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **call_kwargs) return out_y, out_kv, nnx.state(merged_layer) checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) @@ -1254,11 +1302,11 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): else: kv_cache = None - input_tokens = decoder_input_tokens if cfg.engram_layers else None - if input_tokens is not None: - layer_kwargs["decoder_input_tokens"] = input_tokens - - y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) + input_tokens = ( + decoder_input_tokens if (cfg.engram_layers or cfg.decoder_block == DecoderBlockType.DEEPSEEK_V4) else None + ) + # Propagation of decoder_input_tokens of shape [B, S] alongside hidden state y of shape [B, S, k, D] + y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache, input_tokens) nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: @@ -1277,7 +1325,10 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): assert isinstance(y, jax.Array) # After the final transformer layer, `y` holds the raw, un-normalized hidden state. 
- if cfg.mhc_expansion_rate > 1: + if self.is_deepseek_v4: + # collapsed shape: [B, S, k, D] -> [B, S, D] via learnable collapse weights + hidden_state = self.hc_head(y) + elif cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) hidden_state = mhc_reduce(y) else: @@ -1401,6 +1452,54 @@ def pure_gemma_fn(graphdef, state_in, y_in): return y + def _apply_deepseek_v4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + page_state, + slot, + bidirectional_mask: Optional[jax.Array] = None, + decoder_input_tokens: Optional[jax.Array] = None, + ): + """Applies DeepSeek-V4 scanned decoder blocks, handling main scan and remainders.""" + cfg = self.config + scan_length = cfg.num_decoder_layers // 2 + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + layer_kwargs = { + "decoder_input_tokens": decoder_input_tokens, + "bidirectional_mask": bidirectional_mask, + } + + # Apply the main scan over the full blocks + if scan_length > 0: + y, self.layers, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % 2 + if num_remaining_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + def pure_deepseek_fn(graphdef, state_in, y_in): + merged_layer = nnx.merge(graphdef, state_in) + out_y, _ = merged_layer( + y_in, *layer_args, previous_chunk=previous_chunk, page_state=page_state, slot=slot, **layer_kwargs + ) + return out_y, nnx.state(merged_layer) + + checkpointed_deepseek_fn = jax.checkpoint(pure_deepseek_fn, policy=policy, prevent_cse=prevent_cse) + + graphdef, state = nnx.split(self.layers_remainder) + y, new_state = checkpointed_deepseek_fn(graphdef, state, y) + nnx.update(self.layers_remainder, 
new_state) + + return y + def decoder_as_linen( config: Config, diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index c334342c61..a2697353df 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -243,7 +243,7 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array: class DeepSeekV4RMSNorm(nnx.Module): - """RMS normalization for DeepSeek-V4 (equivalent to T5LayerNorm).""" + """RMS normalization for DeepSeek-V4.""" def __init__( self, @@ -257,8 +257,9 @@ def __init__( self.dtype = dtype self.weight_dtype = weight_dtype - # Initialize learnable scale weight to ones matching T5LayerNorm behavior + # Initialize learnable scale weight to ones self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype)) + self.scale = self.weight def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # [B, S, D] where D = hidden_size @@ -296,6 +297,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # Calculate variance across features axis variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) # [..., 1] - # Apply reciprocal square root and cast back to active precision - normalized = x_f32 * lax.rsqrt(variance + self.eps) # [..., D] - return jnp.asarray(normalized, self.dtype) # [..., D] + # Apply reciprocal square root, cast to active precision, and multiply + inv_norm = jnp.asarray(lax.rsqrt(variance + self.eps), self.dtype) # [..., 1] + x_active = jnp.asarray(x, self.dtype) # [..., D] + return x_active * inv_norm # [..., D] diff --git a/src/maxtext/models/deepseek_v4.py b/src/maxtext/models/deepseek_v4.py new file mode 100644 index 0000000000..70f81c5da1 --- /dev/null +++ b/src/maxtext/models/deepseek_v4.py @@ -0,0 +1,425 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decoder Layer and Scannable Block definitions for DeepSeek-V4. + +DeepSeek-V4 Decoder Layer Data Flow Guide: +`B` = batch_size, `S` = sequence_length, `k` = hc_mult (expansion rate), `D` = hidden_size + + Parallel Streams Input [B, S, k, D] + │ + ├───► [mHC Pre-norm & Mapping] ──► [B, S, k * D] ──► Flat-Norm + │ │ + │ [pre_alpha / pre_beta] + │ │ + │ ▼ + │ Sigmoid Logits + │ │ + │ ▼ + │ "pre" weights [B, S, k] + │ │ + ├─────────────────────────────────────────────────────────┼────────────┐ + ▼ ▼ │ + Parallel Streams [B, S, k, D] Collapse Sum │ + │ │ │ + [mHC Res-Mapping] ▼ │ + │ Collapsed [B, S, D] │ + [res_alpha / res_beta] │ │ + │ RMSNorm Pre-Attn │ + ▼ │ │ + Sigmoid Logits [B, S, k, k] ▼ │ + │ DeepSeekV4Attention │ + Sinkhorn-Knopp │ │ + │ ▼ │ + Doubly Stochastic "comb" Attn Output [B, S, D] │ + │ │ │ + ▼ [mHC Post-Mapping] │ + Multiply [post_alpha / beta] │ + │ │ │ + ▼ ▼ │ + Mixed Residual Sigmoid Logits │ + [B, S, k, D] │ │ + │ ▼ │ + │ "post" weights [B, S, k]│ + │ │ │ + │ ▼ │ + │ Expanded Output │ + │ [B, S, k, D] │ + │ │ │ + └───────────────────────► ( + ) ◄─────────────────────────┘ │ + │ │ + ▼ │ + Attention Site Output │ + [B, S, k, D] │ + │ │ + ▼ │ + Experts MoE FFN Site │ + (Same flow: Collapse -> MoE -> Expand) │ + │ │ + ▼ │ + Layer Output [B, S, k, D] ◄──────────────────────────┘ +""" + +from typing import Any, Optional +from flax import nnx +import jax +from jax.ad_checkpoint import checkpoint_name +import jax.numpy as jnp +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, HyperConnectionType, 
MODEL_MODE_PREFILL +from maxtext.layers import initializers +from maxtext.layers import mhc +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.layers.attention_compressed import DeepSeekV4Attention +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.utils import max_utils +from maxtext.utils.sharding import create_sharding +from maxtext.utils.sharding import maybe_shard_with_logical + + +def get_attention_type(compress_ratio: int) -> str: + """Returns the attention type string corresponding to the given compression ratio.""" + if compress_ratio == 0: + return "sliding_attention" + elif compress_ratio == 4: + return "compressed_sparse_attention" + else: + return "heavily_compressed_attention" + + +class DeepSeekV4DecoderLayer(nnx.Module): + """Transformer decoder layer for DeepSeek-V4. + + This layer unconditionally implements routed and shared MoE and unconditionally + applies Manifold-Constrained Hyper-Connections (mHC) to both attention and FFN block outputs. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + compress_ratio: int = 4, + layer_idx: int = 0, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + self.compress_ratio = compress_ratio + self.layer_idx = layer_idx + + batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode) + self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim) + + # Pre-attention normalization layer + self.pre_self_attention_layer_norm = DeepSeekV4RMSNorm( + hidden_size=self.config.emb_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + ) + + # Compressed multi-head attention module. 
+ num_heads = ( + self.config.num_query_heads if self.config.num_query_heads is not None else self.config.num_attention_heads + ) + attention_type = get_attention_type(self.compress_ratio) + + self.self_attention = DeepSeekV4Attention( + hidden_size=self.config.emb_dim, + q_lora_rank=self.config.q_lora_rank, + head_dim=self.config.head_dim, + num_heads=num_heads, + config=config, + layer_idx=layer_idx, + eps=self.config.normalization_layer_epsilon, + weight_dtype=self.config.weight_dtype, + dtype=self.config.dtype, + attention_type=attention_type, + rngs=self.rngs, + ) + + # Manifold-constrained hyper-connection wrapper for attention block outputs. + self.mhc_attention = mhc.ManifoldConstrainedHyperConnections( + config=self.config, + dim=self.config.emb_dim, + mesh=self.mesh, + rngs=self.rngs, + ) + + # Pre-FFN normalization layer + self.post_self_attention_layer_norm = DeepSeekV4RMSNorm( + hidden_size=self.config.emb_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + ) + + # Routed sparse and shared experts mixture-of-experts FFN module. + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed_moe", None), + weight_dtype=self.config.weight_dtype, + dtype=self.config.dtype, + quant=self.quant, + rngs=self.rngs, + layer_idx=self.layer_idx, + ) + + # Manifold-constrained hyper-connection wrapper for FFN block outputs. 
+ self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections( + config=self.config, + dim=self.config.emb_dim, + mesh=self.mesh, + rngs=self.rngs, + ) + + self.out_sharding = create_sharding(self.mesh, self.logical_axis_names, rules=self.config.logical_axis_rules) + + @property + def logical_axis_names(self): + """Generate logical names for activations dynamically decoupling length dimensions.""" + length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length" + return ["activation_batch", length_name, "activation_embed"] + + def with_logical_constraint(self, x): + """Applies sharding constraints over logical axes.""" + return maybe_shard_with_logical( + x, + logical_axes=self.logical_axis_names, + mesh=self.mesh, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + extra_stack_level=1, + rules=self.config.logical_axis_rules, + ) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: Optional[jnp.ndarray] = None, + decoder_positions: Optional[jnp.ndarray] = None, + deterministic: bool = True, + model_mode: str = "train", + previous_chunk: Optional[jnp.ndarray] = None, + page_state: Any = None, + slot: Any = None, + bidirectional_mask: Optional[jnp.ndarray] = None, + kv_cache: Any = None, + attention_metadata: Any = None, + cos: Optional[jnp.ndarray] = None, + sin: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_input_tokens: Optional[jnp.ndarray] = None, + ): + # inputs shape: [B, S, k, D] (where B = batch, S = sequence length, k = expansion rate, D = hidden dim) + if isinstance(inputs, tuple): + inputs = inputs[0] + + if decoder_positions is None and position_ids is not None: + decoder_positions = position_ids + if decoder_segment_ids is None: + decoder_segment_ids = jnp.zeros(inputs.shape[:2], dtype=jnp.int32) + + # Apply constraint to inputs: [B, S, k, D] -> [B, S, k, D] + x = self.with_logical_constraint(inputs) + x = 
checkpoint_name(x, "decoder_layer_input") + + # 1. Attention hyper-connection block + # intermediate_inputs: [B, S, k, D] -> [B, S, k, D] + intermediate_inputs, _ = self.mhc_attention( + norm_fn=self.pre_self_attention_layer_norm, + branch_fn=self.self_attention, + x=x, + mhc_type=HyperConnectionType.ATTENTION, + attention_mask=bidirectional_mask, + cos=cos, + sin=sin, + position_ids=decoder_positions, + ) + + # 2. Experts MoE FFN hyper-connection block + # Inputs: intermediate_inputs: [B, S, k, D], decoder_input_tokens (input_ids): [B, S] + # Outputs output: [B, S, k, D] + output, metadata = self.mhc_mlp( + norm_fn=self.post_self_attention_layer_norm, + branch_fn=self.mlp, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + input_ids=decoder_input_tokens, + ) + + load_balance_loss = metadata["load_balance_loss"] + if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + self.sow("intermediates", "moe_lb_loss", load_balance_loss) + + # Final output constraint application: [B, S, k, D] -> [B, S, k, D] + output = self.with_logical_constraint(output) + + if self.config.scan_layers: + return output, None + else: + return output, kv_cache + + +DeepSeekV4DecoderLayerToLinen = nnx_wrappers.to_linen_class( + DeepSeekV4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekV4ScannableBlock(nnx.Module): + """A repeating cyclical block of DeepSeek-V4 decoder layers for compiler scan loops.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + num_of_layers: int = 2, + layer_offset: int = 0, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + self.num_of_layers = num_of_layers + self.layer_offset = layer_offset + + for layer_id in range(self.num_of_layers): + abs_layer_id = self.layer_offset + layer_id + # Retrieve 
layer-specific compression ratio from configuration to support sliding window attention + # at boundary layers and alternating compressed sparse/heavily compressed attention. + compress_ratio = self.config.compress_ratios[abs_layer_id] + layer_name = f"layers_{layer_id}" + layer = DeepSeekV4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=compress_ratio, + layer_idx=abs_layer_id, + ) + setattr(self, layer_name, layer) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: jnp.ndarray, + decoder_positions: jnp.ndarray, + deterministic: bool, + model_mode: str, + slot: Any = None, + page_state: Any = None, + previous_chunk: Optional[jnp.ndarray] = None, + bidirectional_mask: Optional[jnp.ndarray] = None, + decoder_input_tokens: Optional[jnp.ndarray] = None, + ): + y = inputs + for layer_id in range(self.num_of_layers): + y, _ = getattr(self, f"layers_{layer_id}")( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + bidirectional_mask=bidirectional_mask, + decoder_input_tokens=decoder_input_tokens, + ) + return y, None + + +DeepSeekV4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeekV4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekV4HyperHead(nnx.Module): + """Final learnable Manifold-Constrained Hyper-Connection (mHC) collapse head. + + This head collapses the parallel streams [B, S, k, D] down to a single + sequence [B, S, D] before applying the final RMSNorm. 
+ """ + + def __init__(self, config: Config, rngs: nnx.Rngs): + self.config = config + self.hc_mult = getattr(config, "mhc_expansion_rate", 4) + self.eps = getattr(config, "hc_eps", 1e-6) + self.dtype = config.dtype + self.weight_dtype = config.weight_dtype + self.matmul_precision = jax.lax.Precision(config.matmul_precision) + + # Scale-free unweighted RMSNorm + self.input_norm = DeepSeekV4UnweightedRMSNorm(eps=config.normalization_layer_epsilon) + + # Parameter variables representing learnable linear projections + scale_init = initializers.nd_dense_init(1.0, "fan_in", "normal") + self.hc_fn = nnx.Param( + scale_init( + rngs.params(), + (self.hc_mult * config.emb_dim, self.hc_mult), + self.weight_dtype, + in_axis=0, + out_axis=1, + ), + out_sharding=("activation_embed", None), + ) + self.hc_base = nnx.Param( + initializers.default_bias_init(rngs.params(), (self.hc_mult,), self.weight_dtype), + out_sharding=(None,), + ) + self.hc_scale = nnx.Param( + initializers.default_scalar_init(rngs.params(), (1,), self.weight_dtype), + out_sharding=(None,), + ) + + def __call__(self, x: jax.Array) -> jax.Array: + # x shape: [B, S, k, D] where B = batch_size, S = sequence_length, k = hc_mult, D = emb_dim + b, s, k, d = x.shape + + # 1. Flatten streams and apply scale-free normalization + # [B, S, k, D] -> [B, S, k * D] + flat = self.input_norm(jnp.reshape(x, (b, s, k * d))) + + # 2. Match precision and project flat features to mixing logits + hc_fn = jnp.asarray(self.hc_fn[...], self.dtype) + hc_base = jnp.asarray(self.hc_base[...], self.dtype) + hc_scale = jnp.asarray(self.hc_scale[...], self.dtype) + + # mixes calculation: [B, S, k * D] @ [k * D, k] -> [B, S, k] + mixes = jnp.einsum("bsm,mk -> bsk", flat, hc_fn, precision=self.matmul_precision) + + # mixes sigmoid weights calculation: [B, S, k] + pre = jax.nn.sigmoid(mixes * hc_scale + hc_base[None, None, :]) + self.eps + + # 3. 
Collapse parallel streams: [B, S, k, D] * [B, S, k] -> [B, S, D] + collapsed = jnp.einsum("bsed,bse -> bsd", x, pre, precision=self.matmul_precision) + return collapsed diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 7691848b7a..659415d421 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -15,6 +15,8 @@ """Tests for DeepSeek-V4 Attention and Compressor parity.""" +import sys +import unittest from collections.abc import Callable from typing import Optional import numpy as np @@ -23,15 +25,18 @@ from torch import nn import jax import jax.numpy as jnp +from jax.sharding import Mesh from flax import nnx -import maxtext.layers.normalizations as jax_norm_module -import maxtext.layers.embeddings as jax_emb_module -import maxtext.layers.linears as jax_linear_module +from maxtext.configs import pyconfig from maxtext.layers.moe import DeepSeekV4TopKRouter, DeepSeekV4HashRouter -from maxtext.layers import attention_compressed -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb +from maxtext.layers import attention_compressed, mhc +from maxtext.models.deepseek_v4 import DeepSeekV4DecoderLayer, DeepSeekV4ScannableBlock, DeepSeekV4HyperHead +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb, Embed from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm from maxtext.layers.linears import DeepSeekGroupedLinear +from maxtext.layers.nnx_decoders import NNXDecoder +import maxtext.common.common_types as ctypes +from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides # ============================================================================== @@ -63,6 +68,7 @@ def __init__(self, **kwargs): self.head_dim = 512 self.q_lora_rank = 1024 self.partial_rotary_factor = 64 / 512 + self.qk_rope_head_dim = 64 
self.max_position_embeddings = 1048576 self.rope_theta = 10000.0 self.compress_rope_theta = 160000.0 @@ -81,6 +87,22 @@ def __init__(self, **kwargs): self.attention_dropout = 0.0 self._attn_implementation = "eager" self.matmul_precision = "default" + self.layer_types = ["compressed_sparse_attention"] * 43 + self.mlp_layer_types = ["hash_moe"] * 43 + self.num_experts_per_tok = 6 + self.n_routed_experts = 256 + self.num_local_experts = 256 + self.n_shared_experts = 1 + self.scoring_func = "sqrtsoftplus" + self.routed_scaling_factor = 1.5 + self.intermediate_size = 2048 + self.hidden_act = "silu" + self.swiglu_limit = 10.0 + self.mlp_bias = False + self.attention_bias = False + self.hc_mult = 4 + self.hc_sinkhorn_iters = 20 + self.hc_eps = 1e-6 # Setup default rope parameters dim = int(self.head_dim * self.partial_rotary_factor) @@ -156,6 +178,19 @@ def get_interface(self, implementation_name, default_fn): dynamic_rope_update = lambda fn: fn maybe_autocast = lambda device_type, enabled: torch.enable_grad() # No-op context +use_experts_implementation = lambda cls: cls + + +class TransformersKwargs(dict): + pass + + +ACT2FN = { + "silu": F.silu, + "sigmoid": torch.sigmoid, + "sqrtsoftplus": lambda x: torch.sqrt(F.softplus(x)), +} + # ============================================================================== # 2. EXACT COPY OF PYTORCH REFERENCE CLASSES (SOURCE OF TRUTH - READ ONLY) # ============================================================================== @@ -972,14 +1007,165 @@ def forward( # ============================================================================== -# 3. 
PYTORCH ROUTER REFERENCE CLASSES (SOURCE OF TRUTH - READ ONLY) +# 2.2 PyTorch Decoder Reference Blocks # ============================================================================== -ACT2FN = { - "sqrtsoftplus": lambda x: torch.sqrt(F.softplus(x)), - "softmax": lambda x: F.softmax(x, dim=-1), - "sigmoid": lambda x: torch.sigmoid(x), -} + +class GradientCheckpointingLayer_PT(nn.Module): + pass + + +class DeepseekV4HyperConnection_PT(nn.Module): + r""" + Manifold-Constrained Hyper-Connections + (mHC) (Xie et al., 2026) to strengthen the conventional residual connections between adjacent + Transformer blocks + + Owns the learned (`fn`, `base`, `scale`) + parameters that turn the incoming `hc_mult` residual streams into collapse / expand + weights. The decoder layer instantiates two of these (one for the attention site, + one for the mlp site). + + ASCII shape guide — `B` = batch, `S` = seq, `H` = hc_mult, `D` = hidden_size:: + + hidden_streams flatten(2) RMSNorm-rescale + F.linear(fn) + [B, S, H, D] ──────────► [B, S, H*D] ─────────────────────────────────► + mix-logits + [B, S, (2+H)*H] + │ + ┌───────────────────────────────────────┴──────────────────────────────┐ + ▼ ▼ ▼ + pre logits post logits comb logits + [B, S, H] [B, S, H] [B, S, H, H] + × scale[0] × scale[1] × scale[2] + + base[:H] + base[H:2H] + base[2H:] + σ() + eps σ() + eps σ() + eps + │ │ │ + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement) row/col normalise + │ + comb + (stream mixer) + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + self.input_norm = DeepseekV4UnweightedRMSNorm_PT(eps=config.rms_norm_eps) + mix = (2 + self.hc_mult) * self.hc_mult + self.fn = nn.Parameter(torch.empty(mix, self.hc_mult * config.hidden_size)) + self.base = nn.Parameter(torch.empty(mix)) + # 3 = number of outputs from the mHC mapping: `pre` (input 
projection + # weights), `post` (sublayer output projection weights), `comb` (the + # H×H residual combine matrix that gets Sinkhorn-projected onto the + # doubly-stochastic manifold). Each output gets its own learned scale. + self.scale = nn.Parameter(torch.empty(3)) + + def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). + `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- + Knopp: starting from the sigmoid-positive matrix, alternate row and + column normalisation for `hc_sinkhorn_iters` steps. `pre` then collapses + the `hc_mult` parallel streams into a single sequence (input projection + into the sublayer); `post` and `comb` are returned for the caller to + apply on the sublayer output. + """ + hc = self.hc_mult + flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) + pre_w, post_w, comb_w = F.linear(flat, self.fn.float()).split([hc, hc, hc * hc], dim=-1) + pre_b, post_b, comb_b = self.base.split([hc, hc, hc * hc]) + pre_scale, post_scale, comb_scale = self.scale.unbind(0) + + pre = torch.sigmoid(pre_w * pre_scale + pre_b) + self.hc_eps + post = 2 * torch.sigmoid(post_w * post_scale + post_b) + comb_logits = comb_w.view(*comb_w.shape[:-1], hc, hc) * comb_scale + comb_b.view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): + comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + # Collapse the `hc_mult` parallel streams down to a single sequence using + # the `pre` weights: one weighted sum across the stream axis, ready for + # the sublayer (attn / MLP). 
+ collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2).to(hidden_streams.dtype) + return post, comb, collapsed + + +DeepseekV4UnweightedRMSNorm = DeepseekV4UnweightedRMSNorm_PT + + +class DeepseekV4HyperHead_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + self.eps = config.hc_eps + self.hc_fn = nn.Parameter(torch.empty(self.hc_mult, self.hc_mult * config.hidden_size)) + self.hc_base = nn.Parameter(torch.empty(self.hc_mult)) + self.hc_scale = nn.Parameter(torch.empty(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + flat = self.input_norm(x.flatten(2).float()) + mixes = F.linear(flat, self.hc_fn.float()) + pre = torch.sigmoid(mixes * self.hc_scale.float() + self.hc_base.float()) + self.eps + return (pre.unsqueeze(-1) * x).sum(dim=2).to(x.dtype) + + +class DeepseekV4MLP_PT(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_experts_implementation +class DeepseekV4Experts_PT(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, 
self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + self.limit = config.swiglu_limit + + def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor) -> torch.Tensor: + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(mask[expert_idx]) + current = self._apply_gate(F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx])) + current = F.linear(current, self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + return self.act_fn(gate) * up class DeepseekV4TopKRouter_PT(nn.Module): @@ -1035,7 +1221,152 @@ def forward( return logits, weights * self.routed_scaling_factor, indices -import unittest +class DeepseekV4SparseMoeBlock_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.is_hash = config.mlp_layer_types[layer_idx] == "hash_moe" + self.gate = DeepseekV4HashRouter_PT(config) if self.is_hash else DeepseekV4TopKRouter_PT(config) + self.experts = DeepseekV4Experts_PT(config) + self.shared_experts = DeepseekV4MLP_PT(config) + + def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None) -> torch.Tensor: + batch, seq_len, hidden_dim = hidden_states.shape + residual = hidden_states + flat = hidden_states.view(-1, hidden_dim) + if self.is_hash: + _, weights, indices = 
self.gate(hidden_states, input_ids) + else: + _, weights, indices = self.gate(hidden_states) + routed = self.experts(flat, indices, weights).view(batch, seq_len, hidden_dim) + return routed + self.shared_experts(residual) + + +class DeepseekV4DecoderLayer_PT(GradientCheckpointingLayer_PT): + r"""DeepSeek-V4 decoder block (paper §2). Differs from a classic residual block in + two places: + + The residual is a stack of `hc_mult` parallel streams kept in shape + `[B, S, hc_mult, D]` throughout the block, mixed in and out via two + :class:`DeepseekV4HyperConnection` modules (Manifold-Constrained Hyper- + Connections / mHC, paper §2.2; Xie et al., 2026). The mHC mappings constrain + the residual transform to the manifold of doubly-stochastic matrices via the + Sinkhorn-Knopp projection — making signal propagation non-expansive across + deep stacks. + + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = DeepseekV4Attention_PT(config, layer_idx) + self.mlp = DeepseekV4SparseMoeBlock_PT(config, layer_idx) + self.input_layernorm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps) + self.attn_hc = DeepseekV4HyperConnection_PT(config) + self.ffn_hc = DeepseekV4HyperConnection_PT(config) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + # hidden_states throughout: [B, S, hc_mult, hidden]. + # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs + # in float); the .to(dtype) puts everything back to the input dtype before mixing + # so both sites stay consistent with `hidden_states`'s entry dtype. + # comb is consumed transposed: indexed as sum_j comb[j, k] * residual[j, d] + # (sum over the FIRST hc axis), equivalent to comb.T @ residual. 
Sinkhorn + # produces a doubly-stochastic but non-symmetric matrix, so the direction matters. + dtype = hidden_states.dtype + post, comb, collapsed = self.attn_hc(hidden_states) + attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) + hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) + + post, comb, collapsed = self.ffn_hc(hidden_states) + mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) + + +def _make_config(config_pt, B, S, D, **kwargs): + """Return a pyconfig Config object suitable for unit tests.""" + kwargs.pop("layer_types", None) + kwargs.pop("attention_type", None) + num_heads = kwargs.pop("num_attention_heads", config_pt.num_attention_heads) + overrides = { + "run_name": "test_run", + "enable_checkpointing": False, + "model_name": "deepseek_v4-tiny", + "decoder_block": "deepseek_v4", + "dtype": "float32", + "weight_dtype": "float32", + "matmul_precision": "highest", + "per_device_batch_size": B, + "max_target_length": S, + "max_prefill_predict_length": S, + "emb_dim": D, + "mhc_expansion_rate": getattr(config_pt, "hc_mult", 4), + "hc_eps": getattr(config_pt, "hc_eps", 1e-6), + "sinkhorn_iterations": getattr(config_pt, "hc_sinkhorn_iters", 20), + "normalization_layer_epsilon": 1e-6, + "head_dim": config_pt.head_dim, + "dropout_rate": 0.0, + "o_groups": config_pt.o_groups, + "o_lora_rank": config_pt.o_lora_rank, + "compress_ratios": [4] * 43, + "compress_rope_theta": 160000.0, + "sliding_window": config_pt.sliding_window, + "index_n_heads": config_pt.index_n_heads, + "index_head_dim": config_pt.index_head_dim, + "index_topk": config_pt.index_topk, + "base_num_query_heads": num_heads, + "q_lora_rank": config_pt.q_lora_rank, + "qk_rope_head_dim": getattr(config_pt, 
"qk_rope_head_dim", 64), + "routed_score_func": getattr(config_pt, "scoring_func", "sqrtsoftplus"), + "num_hash_layers": 43, + "rope_max_timescale": config_pt.rope_theta, + "rope_type": "default", + "max_position_embeddings": config_pt.max_position_embeddings, + "shard_mode": "auto", + "debug_sharding": False, + "scan_layers": False, + "remat_policy": "full", + "num_vocab_tiling": 1, + "base_mlp_dim": config_pt.moe_intermediate_size, + "mlp_activations": ["silu"], + "fused_mlp": False, + "megablox": False, + "sparse_matmul": False, + "use_gather_mosaic_kernel": False, + "load_balance_loss_weight": 0.0, + "routed_bias": False, + "dense_init_scale": 1.0, + "moe_expert_input_dim": -1, + "num_experts": 16, + "num_experts_per_tok": 1, + "mlp_bias": False, + "float32_gate_logits": False, + "use_random_routing": False, + "routed_scaling_factor": 1.0, + "attention": "dot_product", + "shared_experts": 1, + "base_moe_mlp_dim": config_pt.moe_intermediate_size, + "vocab_size": getattr(config_pt, "vocab_size", 128), + **kwargs, + } + extra_args = get_decoupled_parallelism_overrides() + merged = {**overrides, **extra_args} + cfg = pyconfig.initialize([sys.argv[0], get_test_config_path()], override_model_config=True, **merged) + if not hasattr(cfg, "trainable_position_size"): + cfg.trainable_position_size = 0 + if not hasattr(cfg, "original_max_position_embeddings"): + cfg.original_max_position_embeddings = cfg.max_position_embeddings + return cfg class DeepSeekV4ParityTest(unittest.TestCase): @@ -1166,154 +1497,12 @@ def test_grouped_linear_parity(self): # Copy the reshaped and transposed weight matrix matching PyTorch's view mapping # [o, i] -> [g, o_g, i] -> [g, i, o_g] - jax_model.weight.value = jnp.array(weight_np.reshape(g, out_features_per_group, i).transpose(0, 2, 1)) + jax_model.kernel.value = jnp.array(weight_np.reshape(g, out_features_per_group, i).transpose(0, 2, 1)) out_jax = jax_model(x_jax) # Verify numerical output parity between frameworks 
np.testing.assert_allclose(out_torch, out_jax, atol=1e-5, rtol=1e-5) - def test_topk_router_parity(self): - # Generate deterministic random inputs for the router comparison. - np.random.seed(42) - B, S, D = 4, 16, 64 - num_experts = 8 - top_k = 4 - routed_scaling_factor = 1.5 - - hidden_states_np = np.random.randn(B, S, D).astype(np.float32) - # Proactively initialize routing weights with a normal distribution to prevent TPU VM NaN bits. - weight_np = np.random.randn(num_experts, D).astype(np.float32) - bias_np = np.random.randn(num_experts).astype(np.float32) - - # 1. Setup PyTorch Reference Top-K Router - config_pt = DeepseekV4Config() - config_pt.num_experts_per_tok = top_k - config_pt.num_local_experts = num_experts - config_pt.hidden_size = D - config_pt.scoring_func = "sqrtsoftplus" - config_pt.routed_scaling_factor = routed_scaling_factor - - py_router = DeepseekV4TopKRouter_PT(config_pt) - py_router.weight.data.copy_(torch.tensor(weight_np)) - py_router.e_score_correction_bias.copy_(torch.tensor(bias_np)) - - # Run forward on PyTorch reference router - # [B, S, D] -> [B * S, D] -> F.linear() -> logits [B * S, num_experts] -> top_k -> [B * S, top_k] - hidden_states_torch = torch.tensor(hidden_states_np) - py_logits, py_weights, py_indices = py_router(hidden_states_torch) - - # 2. 
Setup JAX/Flax NNX Equivalent Router - class MockJaxConfig: - - def __init__(self): - self.num_experts_per_tok = top_k - self.num_experts = num_experts - self.emb_dim = D - self.moe_expert_input_dim = D - self.routed_scaling_factor = routed_scaling_factor - self.routed_score_func = "sqrtsoftplus" - self.dtype = jnp.float32 - self.weight_dtype = jnp.float32 - - config_jax = MockJaxConfig() - rngs = nnx.Rngs(42) - # JAX/Flax NNX target initialization - jax_router = DeepSeekV4TopKRouter(config=config_jax, mesh=None, rngs=rngs) - # Copy weights from PyTorch to JAX (transpose because shape is [D, num_experts] in JAX) - # PyTorch weight: [num_experts, D] -> JAX kernel: [D, num_experts] - jax_router.kernel.value = jnp.array(weight_np.T) - jax_router.e_score_correction_bias.value = jnp.array(bias_np) - - # Run forward on JAX router - # [B, S, D] -> flat [B * S, D] -> matmul -> logits [B * S, num_experts] -> top_k -> [B * S, top_k] - hidden_states_jax = jnp.array(hidden_states_np) - jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax) - - # 3. Parity assertions - # Compare raw logits output parity - np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) - - # Sort indices and corresponding weights to ensure order-agnostic parity, - # avoiding differences caused by implementation sorting quirks under sorted=False in PyTorch. 
- py_sort_idx = np.argsort(py_indices.numpy(), axis=-1) - jax_sort_idx = np.argsort(np.array(jax_indices), axis=-1) - - py_indices_sorted = np.take_along_axis(py_indices.numpy(), py_sort_idx, axis=-1) - jax_indices_sorted = np.take_along_axis(np.array(jax_indices), jax_sort_idx, axis=-1) - - py_weights_sorted = np.take_along_axis(py_weights.detach().numpy(), py_sort_idx, axis=-1) - jax_weights_sorted = np.take_along_axis(np.array(jax_weights), jax_sort_idx, axis=-1) - - np.testing.assert_array_equal(jax_indices_sorted, py_indices_sorted) - np.testing.assert_allclose(jax_weights_sorted, py_weights_sorted, atol=1e-5, rtol=1e-5) - - def test_hash_router_parity(self): - # Generate deterministic random inputs for Hash Router comparison. - np.random.seed(42) - B, S, D = 4, 16, 64 - num_experts = 8 - top_k = 4 - routed_scaling_factor = 1.5 - vocab_size = 128 - - hidden_states_np = np.random.randn(B, S, D).astype(np.float32) - # Generate input token IDs to lookup static hash routing indices. - input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int64) - - weight_np = np.random.randn(num_experts, D).astype(np.float32) - # Setup static routing table - tid2eid_np = np.random.randint(0, num_experts, size=(vocab_size, top_k)).astype(np.int64) - - # 1. Setup PyTorch Reference Hash Router - config_pt = DeepseekV4Config() - config_pt.num_experts_per_tok = top_k - config_pt.num_local_experts = num_experts - config_pt.hidden_size = D - config_pt.scoring_func = "sqrtsoftplus" - config_pt.routed_scaling_factor = routed_scaling_factor - config_pt.vocab_size = vocab_size - - py_router = DeepseekV4HashRouter_PT(config_pt) - py_router.weight.data.copy_(torch.tensor(weight_np)) - py_router.tid2eid.copy_(torch.tensor(tid2eid_np)) - - # Run forward on PyTorch reference router - hidden_states_torch = torch.tensor(hidden_states_np) - input_ids_torch = torch.tensor(input_ids_np) - py_logits, py_weights, py_indices = py_router(hidden_states_torch, input_ids_torch) - - # 2. 
Setup JAX/Flax NNX Equivalent Router - class MockJaxConfig: - - def __init__(self): - self.num_experts_per_tok = top_k - self.num_experts = num_experts - self.emb_dim = D - self.moe_expert_input_dim = D - self.routed_scaling_factor = routed_scaling_factor - self.routed_score_func = "sqrtsoftplus" - self.vocab_size = vocab_size - self.dtype = jnp.float32 - self.weight_dtype = jnp.float32 - - config_jax = MockJaxConfig() - rngs = nnx.Rngs(42) - jax_router = DeepSeekV4HashRouter(config=config_jax, mesh=None, rngs=rngs) - # Copy weight and lookup table parameter states. - jax_router.kernel.value = jnp.array(weight_np.T) - jax_router.tid2eid.value = jnp.array(tid2eid_np, dtype=jnp.int32) - - # Run forward on JAX router - hidden_states_jax = jnp.array(hidden_states_np) - input_ids_jax = jnp.array(input_ids_np) - jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax, input_ids_jax) - - # 3. Parity assertions - # Logits, weights, and selected index array checks. - np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) - np.testing.assert_array_equal(jax_indices, py_indices.numpy()) - np.testing.assert_allclose(py_weights.detach().numpy(), jax_weights, atol=1e-5, rtol=1e-5) - def test_hca_compressor_parity(self): # Configure deterministic seeds for parity reproducibility np.random.seed(42) @@ -1333,6 +1522,7 @@ def test_hca_compressor_parity(self): config = DeepseekV4Config() config.hidden_size = D config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) config.compress_rates["heavily_compressed_attention"] = compress_rate config.rms_norm_eps = 1e-6 @@ -1341,9 +1531,7 @@ def test_hca_compressor_parity(self): torch.nn.init.normal_(torch_model.position_bias, std=0.02) # Map JAX layer using matching parameters - jax_config = DeepseekV4Config() - jax_config.compress_ratios = [compress_rate] * 43 - jax_config.compress_rope_theta = 160000.0 + jax_config = _make_config(config, B, S, D, 
compress_ratios=[compress_rate] * 43) rngs = nnx.Rngs(42) jax_model = attention_compressed.HCACompressor( @@ -1423,12 +1611,16 @@ def test_indexer_parity(self): torch.nn.init.normal_(torch_model.position_bias, std=0.02) # Map JAX equivalent Indexer module - jax_config = DeepseekV4Config() - jax_config.index_n_heads = num_heads - jax_config.index_head_dim = index_head_dim - jax_config.index_topk = index_topk - jax_config.compress_ratios = [compress_rate] * 43 - jax_config.compress_rope_theta = 160000.0 + jax_config = _make_config( + config, + B, + S, + D, + index_n_heads=num_heads, + index_head_dim=index_head_dim, + index_topk=index_topk, + compress_ratios=[compress_rate] * 43, + ) rngs = nnx.Rngs(42) jax_model = attention_compressed.DeepSeekV4Indexer( @@ -1494,6 +1686,7 @@ def test_csa_compressor_parity(self): config.hidden_size = D config.q_lora_rank = D_rank config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) config.index_n_heads = num_heads config.index_head_dim = index_head_dim config.index_topk = index_topk @@ -1504,14 +1697,16 @@ def test_csa_compressor_parity(self): torch.nn.init.normal_(torch_model.position_bias, std=0.02) torch.nn.init.normal_(torch_model.indexer.position_bias, std=0.02) - # JAX CSA Compressor - jax_config = DeepseekV4Config() - jax_config.head_dim = D_head - jax_config.index_n_heads = num_heads - jax_config.index_head_dim = index_head_dim - jax_config.index_topk = index_topk - jax_config.compress_ratios = [compress_rate] * 43 - jax_config.compress_rope_theta = 160000.0 + jax_config = _make_config( + config, + B, + S, + D, + index_n_heads=num_heads, + index_head_dim=index_head_dim, + index_topk=index_topk, + compress_ratios=[compress_rate] * 43, + ) rngs = nnx.Rngs(42) jax_model = attention_compressed.CSACompressor( @@ -1579,8 +1774,6 @@ def test_csa_compressor_parity(self): ) topk_jax_np = np.array(topk_jax) - print("topk_torch:", topk_torch) - print("topk_jax:", topk_jax_np) 
np.testing.assert_allclose(topk_torch, topk_jax_np, atol=1e-5, rtol=1e-5) # Check complete parity of gathered/indexed keys @@ -1608,6 +1801,7 @@ def test_attention_layer_parity(self): config.hidden_size = D config.q_lora_rank = D_rank config.head_dim = D_head + config.qk_rope_head_dim = int(D_head * (64 / 512)) config.num_attention_heads = num_heads config.num_key_value_heads = 1 config.compress_rates["heavily_compressed_attention"] = compress_rate @@ -1616,7 +1810,7 @@ def test_attention_layer_parity(self): # Generate reference position embeddings (cos, sin) torch_emb = DeepseekV4RotaryEmbedding_PT(config) - cos_torch, sin_torch = torch_emb(x_torch, position_ids_torch, layer_type="main") + cos_torch, sin_torch = torch_emb(x_torch, position_ids_torch, layer_type="compress") cos_jax = jnp.array(cos_torch.detach().numpy()) sin_jax = jnp.array(sin_torch.detach().numpy()) @@ -1627,16 +1821,17 @@ def test_attention_layer_parity(self): if torch_model.compressor is not None: torch.nn.init.normal_(torch_model.compressor.position_bias, std=0.02) - jax_config = DeepseekV4Config() - jax_config.hidden_size = D - jax_config.q_lora_rank = D_rank - jax_config.head_dim = D_head - jax_config.num_attention_heads = num_heads - jax_config.compress_ratios = [compress_rate] * 10 - jax_config.compress_rope_theta = 160000.0 - jax_config.layer_types = ["heavily_compressed_attention"] * 10 - jax_config.o_groups = config.o_groups - jax_config.o_lora_rank = config.o_lora_rank + jax_config = _make_config( + config, + B, + S, + D, + num_attention_heads=num_heads, + compress_ratios=[compress_rate] * 10, + layer_types=["heavily_compressed_attention"] * 10, + o_groups=config.o_groups, + o_lora_rank=config.o_lora_rank, + ) rngs = nnx.Rngs(42) jax_model = attention_compressed.DeepSeekV4Attention( @@ -1647,6 +1842,7 @@ def test_attention_layer_parity(self): config=jax_config, layer_idx=0, eps=1e-6, + attention_type="heavily_compressed_attention", rngs=rngs, ) @@ -1662,7 +1858,7 @@ def 
test_attention_layer_parity(self): w_o_a_np = torch_model.o_a_proj.weight.detach().numpy() in_features_per_group = num_heads * D_head // config.o_groups w_o_a_np = w_o_a_np.reshape(config.o_groups, -1, in_features_per_group).transpose(0, 2, 1) - jax_model.o_a_proj.weight[...] = jnp.array(w_o_a_np) + jax_model.o_a_proj.kernel[...] = jnp.array(w_o_a_np) jax_model.o_b_proj.kernel[...] = jnp.array(torch_model.o_b_proj.weight.detach().numpy().T) jax_model.sinks[...] = jnp.array(torch_model.sinks.detach().numpy()) @@ -1696,6 +1892,567 @@ def test_attention_layer_parity(self): # Check complete numerical parity of coordination attention layers np.testing.assert_allclose(out_torch_np, out_jax_np, atol=1e-5, rtol=1e-5) + def test_topk_router_parity(self): + # Generate deterministic random inputs for the router comparison. + np.random.seed(42) + B, S, D = 2, 8, 64 + num_experts = 16 + top_k = 6 + routed_scaling_factor = 1.5 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + weight_np = np.random.randn(num_experts, D).astype(np.float32) + e_score_correction_bias_np = np.random.randn(num_experts).astype(np.float32) + + # 1. Setup PyTorch Reference Router + config_pt = DeepseekV4Config( + num_experts_per_tok=top_k, + num_local_experts=num_experts, + hidden_size=D, + routed_scaling_factor=routed_scaling_factor, + scoring_func="sqrtsoftplus", + ) + py_router = DeepseekV4TopKRouter_PT(config_pt) + py_router.weight.data = torch.tensor(weight_np) + py_router.e_score_correction_bias.data = torch.tensor(e_score_correction_bias_np) + + # Run forward on PyTorch router + hidden_states_torch = torch.tensor(hidden_states_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch) + + # 2. 
Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_router = DeepSeekV4TopKRouter(config=config_jax, mesh=None, rngs=rngs) + + # Copy weight and correction bias parameters using Flax NNX attribute variable assignments. + jax_router.kernel[...] = jnp.array(weight_np.T) + jax_router.e_score_correction_bias[...] = jnp.array(e_score_correction_bias_np) + + # Run forward on JAX router + hidden_states_jax = jnp.array(hidden_states_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax) + + # 3. Parity assertions + # Compare raw logits directly. + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + + # Symmetrically, the order of the chosen top-k experts can differ (unsorted vs JAX sort). + # Sort both index selections and weight selections row-by-row (token-by-token) before comparison. + py_ind_np = py_indices.numpy() + py_w_np = py_weights.detach().numpy() + jax_ind_np = np.array(jax_indices) + jax_w_np = np.array(jax_weights) + + # Sort index arrays row-by-row, and order the corresponding weights array matching the index sort order. + for i in range(py_ind_np.shape[0]): + py_sort_order = np.argsort(py_ind_np[i]) + py_ind_np[i] = py_ind_np[i][py_sort_order] + py_w_np[i] = py_w_np[i][py_sort_order] + + jax_sort_order = np.argsort(jax_ind_np[i]) + jax_ind_np[i] = jax_ind_np[i][jax_sort_order] + jax_w_np[i] = jax_w_np[i][jax_sort_order] + + # Assert sorted indices and weights are mathematically identical! 
+ np.testing.assert_array_equal(jax_ind_np, py_ind_np) + np.testing.assert_allclose(py_w_np, jax_w_np, atol=1e-5, rtol=1e-5) + + def test_hash_router_parity(self): + # Generate deterministic random inputs for static hash router comparison. + np.random.seed(42) + B, S, D = 2, 8, 64 + num_experts = 16 + top_k = 6 + routed_scaling_factor = 1.5 + vocab_size = 32 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int32) + weight_np = np.random.randn(num_experts, D).astype(np.float32) + tid2eid_np = np.random.randint(0, num_experts, size=(vocab_size, top_k)).astype(np.int32) + + # 1. Setup PyTorch Reference Router + config_pt = DeepseekV4Config( + num_experts_per_tok=top_k, + num_local_experts=num_experts, + hidden_size=D, + routed_scaling_factor=routed_scaling_factor, + vocab_size=vocab_size, + scoring_func="sqrtsoftplus", + ) + py_router = DeepseekV4HashRouter_PT(config_pt) + py_router.weight.data = torch.tensor(weight_np) + py_router.tid2eid.data = torch.tensor(tid2eid_np).long() + + # Run forward on PyTorch router + hidden_states_torch = torch.tensor(hidden_states_np) + input_ids_torch = torch.tensor(input_ids_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch, input_ids_torch) + + # 2. Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.vocab_size = vocab_size + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_router = DeepSeekV4HashRouter(config=config_jax, mesh=None, rngs=rngs) + + # Copy weight and lookup table parameter states using clean Flax NNX assignments. + jax_router.kernel[...] = jnp.array(weight_np.T) + jax_router.tid2eid[...] 
= jnp.array(tid2eid_np, dtype=jnp.int32) + + # Run forward on JAX router + hidden_states_jax = jnp.array(hidden_states_np) + input_ids_jax = jnp.array(input_ids_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax, input_ids_jax) + + # 3. Parity assertions + # Logits, weights, and selected index array checks. + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + np.testing.assert_array_equal(jax_indices, py_indices.numpy()) + np.testing.assert_allclose(py_weights.detach().numpy(), jax_weights, atol=1e-5, rtol=1e-5) + + def test_hyperhead_parity(self): + # Verify isolated parametric collapse HyperHead parity E2E! + np.random.seed(42) + B, S, k, D = 2, 4, 4, 128 + x_np = np.random.randn(B, S, k, D).astype(np.float32) + hc_fn_np = np.random.randn(k, k * D).astype(np.float32) + hc_base_np = np.random.randn(k).astype(np.float32) + hc_scale_np = np.random.randn(1).astype(np.float32) + + config_pt = DeepseekV4Config( + hc_mult=k, + hidden_size=D, + rms_norm_eps=1e-6, + hc_eps=1e-6, + ) + py_head = DeepseekV4HyperHead_PT(config_pt) + py_head.hc_fn.data = torch.tensor(hc_fn_np) + py_head.hc_base.data = torch.tensor(hc_base_np) + py_head.hc_scale.data = torch.tensor(hc_scale_np) + + # Run forward on PyTorch reference + x_torch = torch.tensor(x_np) + out_torch = py_head(x_torch) + + # Setup JAX DeepSeekV4HyperHead equivalent NNX module + class MockJaxConfig: + + def __init__(self): + self.emb_dim = D + self.mhc_expansion_rate = k + self.hc_eps = 1e-6 + self.normalization_layer_epsilon = 1e-6 + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + self.matmul_precision = "default" + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_head = DeepSeekV4HyperHead(config=config_jax, rngs=rngs) + + # Copy weight matrices and parameter states cleanly + # Shape mappings: + # PyTorch: hc_fn has shape [k, k * D], mixes = F.linear(flat, hc_fn) -> flat @ hc_fn.T + # JAX: hc_fn has shape [k * D, k], mixes = flat @ 
hc_fn
+        # Therefore, JAX weight = PyTorch weight.T
+        jax_head.hc_fn[...] = jnp.array(hc_fn_np.T)
+        jax_head.hc_base[...] = jnp.array(hc_base_np)
+        jax_head.hc_scale[...] = jnp.array(hc_scale_np)
+
+        # Run forward passes on identical random batch stream inputs [B, S, k, D]
+        x_jax = jnp.array(x_np)
+        out_jax = jax_head(x_jax)
+
+        # Assert numerical parity within atol=1e-5 / rtol=1e-5 end-to-end.
+        np.testing.assert_allclose(out_torch.detach().numpy(), np.array(out_jax), atol=1e-5, rtol=1e-5)
+
+    def test_full_model_stack_parity(self):
+        """Verifies complete, scannable multi-layer decoder stack E2E logits parity.
+
+        This E2E test validates that:
+        1. Parallel stream transformations [B, S, hc_mult, D] sequence correctly.
+        2. Manifold-Constrained Hyper-Connections (mHC) perform identical Sinkhorn
+           projections across frameworks.
+        3. The JAX scanned compiler (scan_layers = True) constructs and executes
+           identical stacked loop parameters compared to unrolled modes (scan_layers = False).
+        """
+        np.random.seed(42)
+        B, S, D, H_mult, vocab_size, num_layers = 2, 8, 128, 4, 32, 3
+
+        # Generate identical input token IDs across frameworks
+        input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int32)
+        position_ids_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32)
+        input_ids_torch = torch.tensor(input_ids_np).long()
+        position_ids_torch = torch.tensor(position_ids_np).long()
+        input_ids_jax = jnp.array(input_ids_np)
+
+        # 1.
Build identical configurations across both frameworks
+        config_pt = DeepseekV4Config()
+        config_pt.hidden_size = D
+        config_pt.intermediate_size = 64
+        config_pt.moe_intermediate_size = 64
+        config_pt.hc_mult = H_mult
+        config_pt.hc_sinkhorn_iters = 8
+        config_pt.rms_norm_eps = 1e-6
+        config_pt.vocab_size = vocab_size
+        config_pt.num_hash_layers = 2
+        config_pt.num_local_experts = 4
+        config_pt.num_experts_per_tok = 2
+        config_pt.num_attention_heads = 4
+        config_pt.num_key_value_heads = 1
+        config_pt.head_dim = 32
+        config_pt.qk_rope_head_dim = 32
+        config_pt.rope_parameters["main"]["partial_rotary_factor"] = 1.0
+        config_pt.rope_parameters["compress"]["partial_rotary_factor"] = 1.0
+        config_pt.q_lora_rank = 64
+        config_pt.o_groups = 2
+        config_pt.o_lora_rank = 64
+        config_pt.index_n_heads = 4
+        config_pt.index_head_dim = 32
+        config_pt.index_topk = 2
+        config_pt.layer_types = ["compressed_sparse_attention", "heavily_compressed_attention", "compressed_sparse_attention"]
+        config_pt.mlp_layer_types = ["hash_moe", "hash_moe", "topk_moe"]
+
+        class DeepseekV4DecoderStack_PT(nn.Module):
+
+            def __init__(self, config: DeepseekV4Config, num_layers: int):
+                super().__init__()
+                self.layers = nn.ModuleList([DeepseekV4DecoderLayer_PT(config, lyr) for lyr in range(num_layers)])
+                self.hc_head = DeepseekV4HyperHead_PT(config)
+                self.norm = DeepseekV4RMSNorm_PT(config.hidden_size, eps=config.rms_norm_eps)
+                self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+                self.logits_dense = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+                self.rotary_emb = DeepseekV4RotaryEmbedding_PT(config)
+
+            def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
+                y = self.embeddings(input_ids)
+                y = y.unsqueeze(2).expand(-1, -1, 4, -1)
+                cos, sin = self.rotary_emb(y[:, :, 0, :], position_ids, layer_type="compress")
+                for layer in self.layers:
+                    y = layer(
+                        y, input_ids=input_ids, position_embeddings=(cos, sin), position_ids=position_ids,
attention_mask=None + ) + collapsed = self.hc_head(y) + normed = self.norm(collapsed) + logits = self.logits_dense(normed) + return logits + + decoder_pt = DeepseekV4DecoderStack_PT(config_pt, num_layers) + + torch.nn.init.normal_(decoder_pt.embeddings.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.norm.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.logits_dense.weight, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_fn, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_base, std=0.02) + torch.nn.init.normal_(decoder_pt.hc_head.hc_scale, std=0.02) + + for layer_pt in decoder_pt.layers: + for param in [ + layer_pt.attn_hc.fn, + layer_pt.attn_hc.base, + layer_pt.attn_hc.scale, + layer_pt.ffn_hc.fn, + layer_pt.ffn_hc.base, + layer_pt.ffn_hc.scale, + layer_pt.self_attn.q_a_proj.weight, + layer_pt.self_attn.q_a_norm.weight, + layer_pt.self_attn.q_b_proj.weight, + layer_pt.self_attn.kv_proj.weight, + layer_pt.self_attn.kv_norm.weight, + layer_pt.self_attn.o_a_proj.weight, + layer_pt.self_attn.o_b_proj.weight, + layer_pt.self_attn.sinks, + layer_pt.mlp.gate.weight, + layer_pt.mlp.experts.gate_up_proj, + layer_pt.mlp.experts.down_proj, + layer_pt.mlp.shared_experts.gate_proj.weight, + layer_pt.mlp.shared_experts.up_proj.weight, + layer_pt.mlp.shared_experts.down_proj.weight, + layer_pt.input_layernorm.weight, + layer_pt.post_attention_layernorm.weight, + ]: + torch.nn.init.normal_(param, std=0.02) + if layer_pt.self_attn.compressor is not None: + comp_pt = layer_pt.self_attn.compressor + for param in [comp_pt.kv_proj.weight, comp_pt.gate_proj.weight, comp_pt.position_bias, comp_pt.kv_norm.weight]: + torch.nn.init.normal_(param, std=0.02) + if hasattr(comp_pt, "indexer"): + for param in [ + comp_pt.indexer.kv_proj.weight, + comp_pt.indexer.gate_proj.weight, + comp_pt.indexer.position_bias, + comp_pt.indexer.kv_norm.weight, + comp_pt.indexer.q_b_proj.weight, + comp_pt.indexer.weights_proj.weight, + ]: + torch.nn.init.normal_(param, std=0.02) + + 
logits_torch = decoder_pt(input_ids_torch, position_ids_torch).detach().numpy() + + devices = jax.devices() + mesh = Mesh(np.array(devices), ("data",)) + + for scan_mode in [False, True]: + jax_config = _make_config( + config_pt, + B, + S, + D, + base_num_decoder_layers=num_layers, + logits_via_embedding=False, + logits_dot_in_fp32=True, + parameter_memory_host_offload=False, + param_scan_axis=0, + use_iota_embed=False, + num_experts=config_pt.num_local_experts, + num_experts_per_tok=config_pt.num_experts_per_tok, + num_hash_layers=config_pt.num_hash_layers, + gradient_accumulation_steps=1, + hardware="cpu", + megablox=False, + sparse_matmul=False, + use_gather_mosaic_kernel=False, + num_vocab_tiling=1, + compress_ratios=[4, 128, 4] * 15, + mlp_dim=config_pt.intermediate_size, + num_attention_heads=config_pt.num_attention_heads, + q_lora_rank=config_pt.q_lora_rank, + head_dim=config_pt.head_dim, + o_groups=config_pt.o_groups, + o_lora_rank=config_pt.o_lora_rank, + index_n_heads=config_pt.index_n_heads, + index_head_dim=config_pt.index_head_dim, + index_topk=config_pt.index_topk, + mlp_activations=["silu", "linear"], + scan_layers=scan_mode, + ) + decoder_jax = NNXDecoder(config=jax_config, mesh=mesh, rngs=nnx.Rngs(0)) + + def get_jax_layer(decoder, lyr): + if not scan_mode: + return decoder.layers[lyr] + return ( + getattr(decoder.layers, f"layers_{lyr}") + if lyr < 2 + else getattr(decoder.layers_remainder, f"layers_{lyr - 2}") + ) + + shared_embedding = Embed(vocab_size, D, config=jax_config, mesh=mesh, rngs=nnx.Rngs(0)) + shared_embedding.embedding[...] = jnp.array(decoder_pt.embeddings.weight.detach().numpy()) + decoder_jax.decoder_norm.scale[...] = jnp.array(decoder_pt.norm.weight.detach().numpy()) + decoder_jax.logits_dense.kernel[...] = jnp.array(decoder_pt.logits_dense.weight.detach().numpy().T) + decoder_jax.hc_head.hc_fn[...] = jnp.array(decoder_pt.hc_head.hc_fn.detach().numpy().T) + decoder_jax.hc_head.hc_base[...] 
= jnp.array(decoder_pt.hc_head.hc_base.detach().numpy()) + decoder_jax.hc_head.hc_scale[...] = jnp.array(decoder_pt.hc_head.hc_scale.detach().numpy()) + + def assign_param(jax_param, pt_value, lyr): + if hasattr(jax_param, "val"): + jax_param[...] = pt_value + else: + jax_param[...] = jnp.expand_dims(pt_value, axis=0) if (scan_mode and lyr < 2) else pt_value + + hc = H_mult + for lyr in range(num_layers): + layer_jax, layer_pt = get_jax_layer(decoder_jax, lyr), decoder_pt.layers[lyr] + assign_param( + layer_jax.pre_self_attention_layer_norm.scale, + jnp.array(layer_pt.input_layernorm.weight.detach().numpy()), + lyr, + ) + assign_param( + layer_jax.post_self_attention_layer_norm.scale, + jnp.array(layer_pt.post_attention_layernorm.weight.detach().numpy()), + lyr, + ) + + assign_param( + layer_jax.self_attention.q_a_proj.kernel, + jnp.array(layer_pt.self_attn.q_a_proj.weight.detach().numpy().T), + lyr, + ) + assign_param( + layer_jax.self_attention.q_a_norm.weight, jnp.array(layer_pt.self_attn.q_a_norm.weight.detach().numpy()), lyr + ) + assign_param( + layer_jax.self_attention.q_b_proj.kernel, + jnp.array(layer_pt.self_attn.q_b_proj.weight.detach().numpy().T), + lyr, + ) + assign_param( + layer_jax.self_attention.kv_proj.kernel, jnp.array(layer_pt.self_attn.kv_proj.weight.detach().numpy().T), lyr + ) + assign_param( + layer_jax.self_attention.kv_norm.weight, jnp.array(layer_pt.self_attn.kv_norm.weight.detach().numpy()), lyr + ) + + w_o_a_np = layer_pt.self_attn.o_a_proj.weight.detach().numpy() + in_features_per_group = config_pt.num_attention_heads * config_pt.head_dim // config_pt.o_groups + w_o_a_np = w_o_a_np.reshape(config_pt.o_groups, -1, in_features_per_group).transpose(0, 2, 1) + assign_param(layer_jax.self_attention.o_a_proj.kernel, jnp.array(w_o_a_np), lyr) + + assign_param( + layer_jax.self_attention.o_b_proj.kernel, + jnp.array(layer_pt.self_attn.o_b_proj.weight.detach().numpy().T), + lyr, + ) + assign_param(layer_jax.self_attention.sinks, 
jnp.array(layer_pt.self_attn.sinks.detach().numpy()), lyr) + + if layer_pt.self_attn.compressor is not None: + comp_pt = layer_pt.self_attn.compressor + comp_jax = layer_jax.self_attention.compressor + assign_param(comp_jax.kv_proj.kernel, jnp.array(comp_pt.kv_proj.weight.detach().numpy().T), lyr) + assign_param(comp_jax.gate_proj.kernel, jnp.array(comp_pt.gate_proj.weight.detach().numpy().T), lyr) + assign_param(comp_jax.position_bias, jnp.array(comp_pt.position_bias.detach().numpy()), lyr) + assign_param(comp_jax.kv_norm.weight, jnp.array(comp_pt.kv_norm.weight.detach().numpy()), lyr) + if hasattr(comp_pt, "indexer"): + assign_param( + comp_jax.indexer.kv_proj.kernel, jnp.array(comp_pt.indexer.kv_proj.weight.detach().numpy().T), lyr + ) + assign_param( + comp_jax.indexer.gate_proj.kernel, jnp.array(comp_pt.indexer.gate_proj.weight.detach().numpy().T), lyr + ) + assign_param(comp_jax.indexer.position_bias, jnp.array(comp_pt.indexer.position_bias.detach().numpy()), lyr) + assign_param(comp_jax.indexer.kv_norm.weight, jnp.array(comp_pt.indexer.kv_norm.weight.detach().numpy()), lyr) + assign_param( + comp_jax.indexer.q_b_proj.kernel, jnp.array(comp_pt.indexer.q_b_proj.weight.detach().numpy().T), lyr + ) + assign_param( + comp_jax.indexer.weights_proj.kernel, + jnp.array(comp_pt.indexer.weights_proj.weight.detach().numpy().T), + lyr, + ) + + moe_pt = layer_pt.mlp + moe_jax = layer_jax.mlp + assign_param(moe_jax.MoeBlock_0.gate.kernel, jnp.array(moe_pt.gate.weight.detach().numpy().T), lyr) + if moe_pt.is_hash: + assign_param( + moe_jax.MoeBlock_0.gate.tid2eid, jnp.array(moe_pt.gate.tid2eid.detach().numpy(), dtype=jnp.int32), lyr + ) + else: + assign_param( + moe_jax.MoeBlock_0.gate.e_score_correction_bias, + jnp.array(moe_pt.gate.e_score_correction_bias.detach().numpy()), + lyr, + ) + + gate_up_np = moe_pt.experts.gate_up_proj.detach().numpy() + intermediate_dim = config_pt.intermediate_size + wi_0_np = gate_up_np[:, :intermediate_dim, :].transpose(0, 2, 1) + wi_1_np = 
gate_up_np[:, intermediate_dim:, :].transpose(0, 2, 1) + wo_np = moe_pt.experts.down_proj.detach().numpy().transpose(0, 2, 1) + + assign_param(moe_jax.MoeBlock_0.wi_0, jnp.array(wi_0_np), lyr) + assign_param(moe_jax.MoeBlock_0.wi_1, jnp.array(wi_1_np), lyr) + assign_param(moe_jax.MoeBlock_0.wo, jnp.array(wo_np), lyr) + + assign_param( + moe_jax.shared_experts.wi_0.kernel, jnp.array(moe_pt.shared_experts.gate_proj.weight.detach().numpy().T), lyr + ) + assign_param( + moe_jax.shared_experts.wi_1.kernel, jnp.array(moe_pt.shared_experts.up_proj.weight.detach().numpy().T), lyr + ) + assign_param( + moe_jax.shared_experts.wo.kernel, jnp.array(moe_pt.shared_experts.down_proj.weight.detach().numpy().T), lyr + ) + + assign_param(layer_jax.mhc_attention.pre_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[:hc].T), lyr) + assign_param( + layer_jax.mhc_attention.post_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[hc : 2 * hc].T), lyr + ) + assign_param(layer_jax.mhc_attention.res_alpha, jnp.array(layer_pt.attn_hc.fn.detach().numpy()[2 * hc :].T), lyr) + assign_param(layer_jax.mhc_attention.pre_beta, jnp.array(layer_pt.attn_hc.base.detach().numpy()[:hc]), lyr) + assign_param( + layer_jax.mhc_attention.post_beta, jnp.array(layer_pt.attn_hc.base.detach().numpy()[hc : 2 * hc]), lyr + ) + assign_param( + layer_jax.mhc_attention.res_beta, + jnp.array(layer_pt.attn_hc.base.detach().numpy()[2 * hc :].reshape(hc, hc)), + lyr, + ) + assign_param(layer_jax.mhc_attention.pre_alpha_scale, jnp.array([layer_pt.attn_hc.scale[0].item()]), lyr) + assign_param(layer_jax.mhc_attention.post_alpha_scale, jnp.array([layer_pt.attn_hc.scale[1].item()]), lyr) + assign_param(layer_jax.mhc_attention.res_alpha_scale, jnp.array([layer_pt.attn_hc.scale[2].item()]), lyr) + + assign_param(layer_jax.mhc_mlp.pre_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[:hc].T), lyr) + assign_param(layer_jax.mhc_mlp.post_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[hc : 2 * hc].T), lyr) + 
assign_param(layer_jax.mhc_mlp.res_alpha, jnp.array(layer_pt.ffn_hc.fn.detach().numpy()[2 * hc :].T), lyr) + assign_param(layer_jax.mhc_mlp.pre_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[:hc]), lyr) + assign_param(layer_jax.mhc_mlp.post_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[hc : 2 * hc]), lyr) + assign_param( + layer_jax.mhc_mlp.res_beta, jnp.array(layer_pt.ffn_hc.base.detach().numpy()[2 * hc :].reshape(hc, hc)), lyr + ) + assign_param(layer_jax.mhc_mlp.pre_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[0].item()]), lyr) + assign_param(layer_jax.mhc_mlp.post_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[1].item()]), lyr) + assign_param(layer_jax.mhc_mlp.res_alpha_scale, jnp.array([layer_pt.ffn_hc.scale[2].item()]), lyr) + + if not scan_mode: + y_pt = decoder_pt.embeddings(input_ids_torch) + y_pt = y_pt.unsqueeze(2).expand(-1, -1, 4, -1) + cos_pt, sin_pt = decoder_pt.rotary_emb(y_pt[:, :, 0, :], position_ids_torch, layer_type="compress") + + y_jax = shared_embedding(input_ids_jax.astype("int32"), model_mode="train") + y_jax = jnp.repeat(jnp.expand_dims(y_jax, axis=2), 4, axis=2).astype(y_jax.dtype) + + np.testing.assert_allclose( + y_pt.detach().numpy(), np.array(y_jax), atol=1e-5, rtol=1e-5, err_msg="Embedding mismatch" + ) + + for lyr in range(num_layers): + layer_pt = decoder_pt.layers[lyr] + layer_jax = decoder_jax.layers[lyr] + + y_pt = layer_pt( + y_pt, + input_ids=input_ids_torch, + position_embeddings=(cos_pt, sin_pt), + position_ids=position_ids_torch, + attention_mask=None, + ) + y_jax, _ = layer_jax( + y_jax, + decoder_segment_ids=jnp.zeros((B, S), dtype=jnp.int32), + decoder_positions=jnp.array(position_ids_np, dtype=jnp.int32), + deterministic=True, + model_mode="train", + decoder_input_tokens=input_ids_jax, + ) + + logits_jax, _, _ = decoder_jax( + shared_embedding=shared_embedding, + decoder_input_tokens=input_ids_jax, + decoder_positions=jnp.array(position_ids_np, dtype=jnp.int32), + decoder_segment_ids=jnp.zeros((B, S), 
dtype=jnp.int32), + deterministic=True, + ) + + np.testing.assert_allclose(logits_torch, np.array(logits_jax), atol=1e-5, rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index 2525a181f1..bf527ee92e 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -767,3 +767,40 @@ def test_gemma4_scanned_layers(self): logits.shape, (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size), ) + + def test_deepseek_v4_scanned_layers(self): + """Test NNXDecoder with deepseek_v4 block and scan_layers=True.""" + cfg = _make_config( + decoder_block="deepseek_v4", + scan_layers=True, + num_decoder_layers=3, + q_lora_rank=1024, + o_lora_rank=1024, + qk_rope_head_dim=64, + compress_ratios=[4, 128, 4], + base_moe_mlp_dim=512, + shared_experts=2, + mhc_expansion_rate=4, + megablox=False, # Disable custom Pallas GMM TPU kernels on CPU testing platforms! + ) + decoder = NNXDecoder( + config=cfg, + mesh=self.mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=self.rngs, + ) + shared_embedding = self._make_shared_embedding(cfg) + ids, segment_ids, positions = self._make_token_inputs(cfg) + + logits, _, _ = decoder( + shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertEqual( + logits.shape, + (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size), + )