From c92f2e06d3bd804eb4a002ee8354ed15df3de135 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Thu, 14 May 2026 17:50:50 +0000 Subject: [PATCH] feat: implement DeepSeek-V4 MoE routing primitives (HashRouter, TopKRouter, RoutedMoE) Implement Mixture of Experts routing gates and execution layers for DeepSeek-V4 integration into MaxText: - HashRouter: Token routing mechanism utilizing MD5 hash projections for deterministic expert assignment. - TopKRouter: Gated top-k router implementing sigmoid scaling and score normalization. - RoutedMoE & RoutedAndSharedMoE: Execution layers supporting layer_idx routing and FP32 expert summation parity. - Parity verification: Extended unit test suite (deepseek_v4_vs_reference_test.py) validating routing parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5. --- src/maxtext/common/common_types.py | 1 + src/maxtext/layers/moe.py | 363 ++++++++++++++++++-- tests/unit/deepseek_v4_vs_reference_test.py | 207 +++++++++++ 3 files changed, 539 insertions(+), 32 deletions(-) diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 86811063a6..529a00d53c 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -93,6 +93,7 @@ class DecoderBlockType(enum.Enum): MISTRAL = "mistral" MIXTRAL = "mixtral" DEEPSEEK = "deepseek" + DEEPSEEK_V4 = "deepseek_v4" GEMMA = "gemma" GEMMA2 = "gemma2" GEMMA3 = "gemma3" diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..18c55b6571 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -306,6 +306,181 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. return output, pre_bias_logits +def _sqrtsoftplus(x: jax.Array) -> jax.Array: + """Computes sqrtsoftplus activation: sqrt(softplus(x)).""" + # [Any] -> [Any] + return jnp.sqrt(jax.nn.softplus(x)) + + +class DeepSeekV4TopKRouter(nnx.Module): + """Top-K Router for DeepSeek-V4 MoE routing. + + Computes logits, normalized routing weights, and expert indices. + """ + + def __init__( + self, + config: ctypes.Config, + mesh: jax.sharding.Mesh, + rngs: nnx.Rngs, + kernel_axes: Tuple[Optional[str], ...] = (), + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim + self.routed_scaling_factor = config.routed_scaling_factor + + # Initialize gate weight matrix. + # Shape: [hidden_dim, num_experts] + kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_shape = (self.hidden_dim, self.num_experts) + kernel_in_axis = np.arange(1) + kernel_out_axis = np.arange(1, 2) + + self.kernel = nnx.Param( + kernel_init( + rngs.params(), + kernel_shape, + config.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + out_sharding=kernel_axes, + ) + + # Load-balancing expert score correction bias. + # Shape: [num_experts] + self.e_score_correction_bias = nnx.Param( + jnp.zeros((self.num_experts,), dtype=jnp.float32), + out_sharding=(kernel_axes[-1] if kernel_axes else None,), + ) + + def __call__(self, hidden_states: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]: + # input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim] + inputs = jnp.asarray(hidden_states, dtype=self.config.dtype) + # [batch, seq_len, hidden_dim] -> [tokens, hidden_dim] + flat = inputs.reshape(-1, self.hidden_dim) + + # Compute raw logits in float32. + # [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts] + kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) + logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) + + # Apply custom scoring function (sqrtsoftplus). + # [tokens, num_experts] -> [tokens, num_experts] + scores = _sqrtsoftplus(logits) + + # Add expert score correction bias and select top-k indices. + # [tokens, num_experts] + [num_experts] -> [tokens, num_experts] + scores_biased = scores + jnp.asarray(self.e_score_correction_bias[...], dtype=jnp.float32) + # [tokens, num_experts] -> [tokens, top_k] + _, indices = jax.lax.top_k(scores_biased, self.top_k) + + # Gather corresponding scores for the selected top-k indices. + # [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k] + weights = jnp.take_along_axis(scores, indices, axis=-1) + + # Normalize weights to sum to 1.0 per token. + # [tokens, top_k] -> [tokens, top_k] + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + + # Scale weights by routed scaling factor. + # [tokens, top_k] -> [tokens, top_k] + scaled_weights = weights * self.routed_scaling_factor + + return ( + logits.astype(self.config.dtype), + scaled_weights.astype(self.config.dtype), + indices, + ) + + +class DeepSeekV4HashRouter(nnx.Module): + """Hash Router for DeepSeek-V4 MoE routing. + + Computes logits, static routing weights based on token IDs, and expert indices. + """ + + def __init__( + self, + config: ctypes.Config, + mesh: jax.sharding.Mesh, + rngs: nnx.Rngs, + kernel_axes: Tuple[Optional[str], ...] = (), + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim + self.routed_scaling_factor = config.routed_scaling_factor + + # Initialize gate weight matrix. + # Shape: [hidden_dim, num_experts] + kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_shape = (self.hidden_dim, self.num_experts) + kernel_in_axis = np.arange(1) + kernel_out_axis = np.arange(1, 2) + + self.kernel = nnx.Param( + kernel_init( + rngs.params(), + kernel_shape, + config.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + out_sharding=kernel_axes, + ) + + # Static token-to-expert mapping table. + # Shape: [vocab_size, top_k] + self.tid2eid = nnx.Param( + jnp.zeros((config.vocab_size, self.top_k), dtype=jnp.int32), + ) + + def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]: + # input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim] + inputs = jnp.asarray(hidden_states, dtype=self.config.dtype) + # [batch, seq_len, hidden_dim] -> [tokens, hidden_dim] + flat = inputs.reshape(-1, self.hidden_dim) + + # Compute raw logits in float32. + # [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts] + kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32) + logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32) + + # Apply custom scoring function (sqrtsoftplus). + # [tokens, num_experts] -> [tokens, num_experts] + scores = _sqrtsoftplus(logits) + + # Look up frozen expert routing indices from input_ids. + # [batch, seq_len] -> [tokens] + flat_input_ids = input_ids.reshape(-1) + # [vocab_size, top_k] sliced at [tokens] -> [tokens, top_k] + indices = self.tid2eid[...][flat_input_ids] + + # Gather corresponding scores for the statically selected expert indices. + # [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k] + weights = jnp.take_along_axis(scores, indices, axis=-1) + + # Normalize weights to sum to 1.0 per token. + # [tokens, top_k] -> [tokens, top_k] + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + + # Scale weights by routed scaling factor. + # [tokens, top_k] -> [tokens, top_k] + scaled_weights = weights * self.routed_scaling_factor + + return ( + logits.astype(self.config.dtype), + scaled_weights.astype(self.config.dtype), + indices, + ) + + class RoutedMoE(nnx.Module): """Implements a routed MoE block.""" @@ -322,6 +497,7 @@ def __init__( weight_dtype: ctypes.DType = jnp.float32, dtype: ctypes.DType = jnp.float32, quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = 0, ): """Initializes the RoutedMoE module. @@ -349,6 +525,7 @@ def __init__( self.dtype = dtype self.quant = quant self.rngs = rngs + self.layer_idx = layer_idx self.moe_expert_input_dim = ( self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim @@ -381,25 +558,33 @@ def __init__( else: self._expert_parallelism_name = "expert" - self.gate = GateLogit( - in_features_shape=self.moe_expert_input_dim, - out_features_shape=self.num_experts, - mesh=self.mesh, - model_name=self.config.model_name, - dtype=jnp.float32 if self.config.float32_gate_logits else self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - use_bias=self.config.routed_bias, - # tpu-inference applies the score function in the fused_moe_gmm kernel, - # so we don't apply it here to avoid redundant computation. - # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58. - score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func, - matmul_precision=self.config.matmul_precision, - shard_mode=config.shard_mode, - rngs=self.rngs, + self.is_hash = self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4 and 0 <= layer_idx < getattr( + config, "num_hash_layers", 3 ) + if self.is_hash: + self.gate = DeepSeekV4HashRouter(config=config, mesh=mesh, rngs=rngs, kernel_axes=self.kernel_axes) + elif self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + self.gate = DeepSeekV4TopKRouter(config=config, mesh=mesh, rngs=rngs, kernel_axes=self.kernel_axes) + else: + self.gate = GateLogit( + in_features_shape=self.moe_expert_input_dim, + out_features_shape=self.num_experts, + mesh=self.mesh, + model_name=self.config.model_name, + dtype=jnp.float32 if self.config.float32_gate_logits else self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + use_bias=self.config.routed_bias, + # tpu-inference applies the score function in the fused_moe_gmm kernel, + # so we don't apply it here to avoid redundant computation. + # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58. + score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func, + matmul_precision=self.config.matmul_precision, + shard_mode=config.shard_mode, + rngs=self.rngs, + ) rule = qpl.get_current_rule("gmm") sparsity_rule = None if rule is not None: @@ -704,7 +889,13 @@ def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) - def apply_ffn_activation(self, layer_w0, layer_w1): """Applies FFN activation function.""" with jax.named_scope("ffn_act"): - if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS: + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + limit = getattr(self.config, "swiglu_limit", 1.0) + layer_w0 = jnp.clip(layer_w0, max=limit) + layer_w1 = jnp.clip(layer_w1, min=-limit, max=limit) + layer_act = self.activation_fn(layer_w0) + intermediate_layer = jnp.multiply(layer_act, layer_w1) + elif self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS: layer_w0 = jnp.clip(layer_w0, min=None, max=self.config.mlp_activations_limit) layer_w1 = jnp.clip(layer_w1, min=-self.config.mlp_activations_limit, max=self.config.mlp_activations_limit) layer_act = self.activation_fn(layer_w0 * 1.702) @@ -715,13 +906,26 @@ def apply_ffn_activation(self, layer_w0, layer_w1): intermediate_layer = jnp.multiply(layer_act, layer_w1) return intermediate_layer.astype(self.dtype) - def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): + def permute( + self, + inputs, + gate_logits, + pre_bias_logits, + use_custom_sort_vjp=True, + rngs=None, + roll_to_expert_id=None, + gate_weights=None, + gate_indices=None, + ): """Permute tokens to group by expert to fit gmm call.""" # reshape inputs (batch, sequence, emb) to (batch * sequence, emb) inputs_shape = inputs.shape bsz_times_seq_len = inputs_shape[0] * inputs_shape[1] inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2])) - weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) + if gate_weights is not None and gate_indices is not None: + weights, selected_experts = gate_weights, gate_indices + else: + weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) lb_loss = None if self.config.load_balance_loss_weight > 0.0: @@ -794,7 +998,7 @@ def unpermute( if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4: # For Llama4, combine using weights of 1 for selected experts reshaped_weights = jnp.ones_like(reshaped_weights) - if self.config.float32_weight_sum: + if self.config.float32_weight_sum or self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: reshaped_intermediate = reshaped_intermediate.astype(jnp.float32) reshaped_weights = reshaped_weights.astype(jnp.float32) output = jnp.einsum( @@ -1033,6 +1237,8 @@ def sparse_matmul( w0_bias, w1_bias, wo_bias, + gate_weights=None, + gate_indices=None, ): """Perform sparse matrix multiplication of inputs and Experts.""" @@ -1255,6 +1461,16 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): ) = get_routed_moe_shardings(is_batch_sharded_by_expert) w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec) + if gate_weights is not None: + gate_weights_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + else: + gate_weights_pspec = None + + if gate_indices is not None: + gate_indices_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + else: + gate_indices_pspec = None + @functools.partial( jax.shard_map, mesh=self.mesh, @@ -1269,6 +1485,8 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): w1_bias_pspec, wo_bias_pspec, P(), # Replicate the input key + gate_weights_pspec, + gate_indices_pspec, ), out_specs=( self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), @@ -1277,7 +1495,7 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): ), check_vma=False, ) - def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): + def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs, g_weights, g_indices): batch_size, sequence_length, _ = x.shape num_expert_parallelism = self.get_expert_parallelism_size() if num_expert_parallelism > 1: @@ -1304,6 +1522,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r self.config.use_custom_sort_vjp, roll_to_expert_id=num_experts_per_shard * expert_shard_id, rngs=rngs, + gate_weights=g_weights, + gate_indices=g_indices, ) # Filter down to the group sizes that apply to only the experts in the @@ -1313,7 +1533,13 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r x = jnp.where(mask[:, None], x, 0) else: x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( - x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs + x, + logits, + pre_bias_logits, + self.config.use_custom_sort_vjp, + rngs, + gate_weights=g_weights, + gate_indices=g_indices, ) if num_expert_parallelism > 1: @@ -1590,8 +1816,24 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): if wo_bias is not None: wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) + if gate_weights is not None: + gate_weights = self._maybe_shard_with_logical(gate_weights, gate_logits_axes) + if gate_indices is not None: + gate_indices = self._maybe_shard_with_logical(gate_indices, gate_logits_axes) + return wrapper( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + self.rngs, + gate_weights, + gate_indices, ) def reshape_and_update_weights(self, weights, indices): @@ -1851,6 +2093,8 @@ def dense_matmul( w0_bias, w1_bias, wo_bias, + gate_weights=None, + gate_indices=None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert @@ -1860,7 +2104,10 @@ def dense_matmul( pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None) ) - top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) + if gate_weights is not None and gate_indices is not None: + top_k_weights, top_k_indices = gate_weights, gate_indices + else: + top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(self.dtype) @@ -2231,13 +2478,31 @@ def retrieve_quantized_weight( return w0_kernel, w1_kernel, wo_kernel def __call__( - self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None + self, + inputs: jax.Array, + gate_inputs: jax.Array | None = None, + out_sharding: NamedSharding | None = None, + input_ids: jax.Array | None = None, + gate_weights: jax.Array | None = None, + gate_indices: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) gate_dtype = jnp.float32 if cfg.float32_gate_logits else cfg.dtype routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) - gate_logits, pre_bias_logits = self.gate(routing_inputs) + + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + batch_size, seq_len = inputs.shape[0], inputs.shape[1] + if self.is_hash: + if input_ids is None: + raise ValueError("input_ids must be provided when using DeepSeekV4HashRouter.") + gate_logits, pre_bias_logits, _ = self.gate(routing_inputs, input_ids) + else: + gate_logits, pre_bias_logits, _ = self.gate(routing_inputs) + gate_logits = gate_logits.reshape(batch_size, seq_len, -1) + pre_bias_logits = pre_bias_logits.reshape(batch_size, seq_len, -1) + else: + gate_logits, pre_bias_logits = self.gate(routing_inputs) wo_kernel = jnp.asarray(self.wo[...], self.dtype) @@ -2284,11 +2549,31 @@ def __call__( wo_bias, ) output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + gate_weights=gate_weights, + gate_indices=gate_indices, ) else: output, lb_loss, bias_updates = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + gate_weights=gate_weights, + gate_indices=gate_indices, ) return output, lb_loss, bias_updates @@ -2306,6 +2591,7 @@ def __init__( weight_dtype: ctypes.DType = jnp.float32, dtype: ctypes.DType = jnp.float32, quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = 0, ): """Initializes the RoutedAndSharedMoE module. @@ -2345,6 +2631,7 @@ def __init__( weight_dtype=self.config.weight_dtype, quant=self.quant, rngs=self.rngs, + layer_idx=layer_idx, ) shared_expert_mlp_dim = ( @@ -2375,12 +2662,24 @@ def __call__( gate_inputs: jax.Array | None = None, intermediate_sharding: NamedSharding | None = None, out_sharding: NamedSharding | None = None, + input_ids: jax.Array | None = None, + gate_weights: jax.Array | None = None, + gate_indices: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: routed_experts, load_balance_loss, moe_bias_updates = self.routed_moe( - inputs, gate_inputs=gate_inputs, out_sharding=out_sharding + inputs, + gate_inputs=gate_inputs, + out_sharding=out_sharding, + input_ids=input_ids, + gate_weights=gate_weights, + gate_indices=gate_indices, ) shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding) - return routed_experts + shared_experts, load_balance_loss, moe_bias_updates + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK_V4: + combined = (routed_experts.astype(jnp.float32) + shared_experts.astype(jnp.float32)).astype(self.dtype) + else: + combined = routed_experts + shared_experts + return combined, load_balance_loss, moe_bias_updates def get_gate_logit( diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 29d2b332f8..3b87f5b3fd 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -27,6 +27,7 @@ import maxtext.layers.normalizations as jax_norm_module import maxtext.layers.embeddings as jax_emb_module import maxtext.layers.linears as jax_linear_module +from maxtext.layers.moe import DeepSeekV4TopKRouter, DeepSeekV4HashRouter # ============================================================================== @@ -904,6 +905,70 @@ def forward( return output, attn_weights +# ============================================================================== +# 3. PYTORCH ROUTER REFERENCE CLASSES (SOURCE OF TRUTH - READ ONLY) +# ============================================================================== + +ACT2FN = { + "sqrtsoftplus": lambda x: torch.sqrt(F.softplus(x)), + "softmax": lambda x: F.softmax(x, dim=-1), + "sigmoid": lambda x: torch.sigmoid(x), +} + + +class DeepseekV4TopKRouter_PT(nn.Module): + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts), persistent=True) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat, self.weight) + scores = self.score_fn(logits) + indices = torch.topk(scores + self.e_score_correction_bias, self.top_k, dim=-1, sorted=False).indices + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4HashRouter_PT(nn.Module): + r""" + Hash routing for the first `mlp_layer_types == "hash_moe"` MoE layers (paper + ยง2.1). Expert selection is determined by a fixed `tid2eid[input_ids]` lookup โ€” + a frozen token-id โ†’ expert-id table โ€” instead of a learned argmax. The learned + gate `weight` still produces the per-expert scores that weight the selected + experts' activations; only the *which-experts* selection is static. + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("tid2eid", torch.zeros(config.vocab_size, self.top_k, dtype=torch.long), persistent=True) + + def forward( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat, self.weight) + scores = self.score_fn(logits) + indices = self.tid2eid[input_ids.reshape(-1)].long() + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + import unittest @@ -1041,6 +1106,148 @@ def test_grouped_linear_parity(self): # Verify numerical output parity between frameworks np.testing.assert_allclose(out_torch, out_jax, atol=1e-5, rtol=1e-5) + def test_topk_router_parity(self): + # Generate deterministic random inputs for the router comparison. + np.random.seed(42) + B, S, D = 4, 16, 64 + num_experts = 8 + top_k = 4 + routed_scaling_factor = 1.5 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + # Proactively initialize routing weights with a normal distribution to prevent TPU VM NaN bits. + weight_np = np.random.randn(num_experts, D).astype(np.float32) + bias_np = np.random.randn(num_experts).astype(np.float32) + + # 1. Setup PyTorch Reference Top-K Router + config_pt = DeepseekV4Config() + config_pt.num_experts_per_tok = top_k + config_pt.num_local_experts = num_experts + config_pt.hidden_size = D + config_pt.scoring_func = "sqrtsoftplus" + config_pt.routed_scaling_factor = routed_scaling_factor + + py_router = DeepseekV4TopKRouter_PT(config_pt) + py_router.weight.data.copy_(torch.tensor(weight_np)) + py_router.e_score_correction_bias.copy_(torch.tensor(bias_np)) + + # Run forward on PyTorch reference router + # [B, S, D] -> [B * S, D] -> F.linear() -> logits [B * S, num_experts] -> top_k -> [B * S, top_k] + hidden_states_torch = torch.tensor(hidden_states_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch) + + # 2. Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + # JAX/Flax NNX target initialization + jax_router = DeepSeekV4TopKRouter(config=config_jax, mesh=None, rngs=rngs) + # Copy weights from PyTorch to JAX (transpose because shape is [D, num_experts] in JAX) + # PyTorch weight: [num_experts, D] -> JAX kernel: [D, num_experts] + jax_router.kernel.value = jnp.array(weight_np.T) + jax_router.e_score_correction_bias.value = jnp.array(bias_np) + + # Run forward on JAX router + # [B, S, D] -> flat [B * S, D] -> matmul -> logits [B * S, num_experts] -> top_k -> [B * S, top_k] + hidden_states_jax = jnp.array(hidden_states_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax) + + # 3. Parity assertions + # Compare raw logits output parity + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + + # Sort indices and corresponding weights to ensure order-agnostic parity, + # avoiding differences caused by implementation sorting quirks under sorted=False in PyTorch. + py_sort_idx = np.argsort(py_indices.numpy(), axis=-1) + jax_sort_idx = np.argsort(np.array(jax_indices), axis=-1) + + py_indices_sorted = np.take_along_axis(py_indices.numpy(), py_sort_idx, axis=-1) + jax_indices_sorted = np.take_along_axis(np.array(jax_indices), jax_sort_idx, axis=-1) + + py_weights_sorted = np.take_along_axis(py_weights.detach().numpy(), py_sort_idx, axis=-1) + jax_weights_sorted = np.take_along_axis(np.array(jax_weights), jax_sort_idx, axis=-1) + + np.testing.assert_array_equal(jax_indices_sorted, py_indices_sorted) + np.testing.assert_allclose(jax_weights_sorted, py_weights_sorted, atol=1e-5, rtol=1e-5) + + def test_hash_router_parity(self): + # Generate deterministic random inputs for Hash Router comparison. + np.random.seed(42) + B, S, D = 4, 16, 64 + num_experts = 8 + top_k = 4 + routed_scaling_factor = 1.5 + vocab_size = 128 + + hidden_states_np = np.random.randn(B, S, D).astype(np.float32) + # Generate input token IDs to lookup static hash routing indices. + input_ids_np = np.random.randint(0, vocab_size, size=(B, S)).astype(np.int64) + + weight_np = np.random.randn(num_experts, D).astype(np.float32) + # Setup static routing table + tid2eid_np = np.random.randint(0, num_experts, size=(vocab_size, top_k)).astype(np.int64) + + # 1. Setup PyTorch Reference Hash Router + config_pt = DeepseekV4Config() + config_pt.num_experts_per_tok = top_k + config_pt.num_local_experts = num_experts + config_pt.hidden_size = D + config_pt.scoring_func = "sqrtsoftplus" + config_pt.routed_scaling_factor = routed_scaling_factor + config_pt.vocab_size = vocab_size + + py_router = DeepseekV4HashRouter_PT(config_pt) + py_router.weight.data.copy_(torch.tensor(weight_np)) + py_router.tid2eid.copy_(torch.tensor(tid2eid_np)) + + # Run forward on PyTorch reference router + hidden_states_torch = torch.tensor(hidden_states_np) + input_ids_torch = torch.tensor(input_ids_np) + py_logits, py_weights, py_indices = py_router(hidden_states_torch, input_ids_torch) + + # 2. Setup JAX/Flax NNX Equivalent Router + class MockJaxConfig: + + def __init__(self): + self.num_experts_per_tok = top_k + self.num_experts = num_experts + self.emb_dim = D + self.moe_expert_input_dim = D + self.routed_scaling_factor = routed_scaling_factor + self.routed_score_func = "sqrtsoftplus" + self.vocab_size = vocab_size + self.dtype = jnp.float32 + self.weight_dtype = jnp.float32 + + config_jax = MockJaxConfig() + rngs = nnx.Rngs(42) + jax_router = DeepSeekV4HashRouter(config=config_jax, mesh=None, rngs=rngs) + # Copy weight and lookup table parameter states. + jax_router.kernel.value = jnp.array(weight_np.T) + jax_router.tid2eid.value = jnp.array(tid2eid_np, dtype=jnp.int32) + + # Run forward on JAX router + hidden_states_jax = jnp.array(hidden_states_np) + input_ids_jax = jnp.array(input_ids_np) + jax_logits, jax_weights, jax_indices = jax_router(hidden_states_jax, input_ids_jax) + + # 3. Parity assertions + # Logits, weights, and selected index array checks. + np.testing.assert_allclose(py_logits.detach().numpy(), jax_logits, atol=1e-5, rtol=1e-5) + np.testing.assert_array_equal(jax_indices, py_indices.numpy()) + np.testing.assert_allclose(py_weights.detach().numpy(), jax_weights, atol=1e-5, rtol=1e-5) + if __name__ == "__main__": unittest.main()