From 13280fd97ada27f21c40417873f6fe7564c76b74 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 12 May 2026 04:05:25 +0000 Subject: [PATCH] Plumbing and core MoE logic for router replay --- .../integration/tunix/tunix_adapter.py | 2 + .../vllm/maxtext_vllm_adapter/adapter.py | 2 +- src/maxtext/layers/decoders.py | 55 ++++-- src/maxtext/layers/moe.py | 170 ++++++++++++++---- src/maxtext/layers/nnx_decoders.py | 41 ++++- src/maxtext/layers/nnx_wrappers.py | 20 +++ src/maxtext/models/deepseek.py | 13 +- src/maxtext/models/gemma4.py | 8 +- src/maxtext/models/mixtral.py | 3 +- src/maxtext/models/models.py | 6 + src/maxtext/models/qwen3.py | 14 +- src/maxtext/models/qwen3_5.py | 5 +- tests/unit/forced_routing_test.py | 109 +++++++++++ 13 files changed, 385 insertions(+), 63 deletions(-) create mode 100644 tests/unit/forced_routing_test.py diff --git a/src/maxtext/integration/tunix/tunix_adapter.py b/src/maxtext/integration/tunix/tunix_adapter.py index d509e512a1..890b0bac9f 100644 --- a/src/maxtext/integration/tunix/tunix_adapter.py +++ b/src/maxtext/integration/tunix/tunix_adapter.py @@ -59,6 +59,7 @@ def __call__( attention_mask: Optional[Array], # [B, L, L] or None decoder_segment_ids: Optional[Array] = None, output_hidden_states: bool = False, # ignored + forced_routed_experts: Optional[Array] = None, ) -> Tuple[Array, None]: """Forward compatible with Tunix Trainers default loss. Returns logits, None. @@ -67,6 +68,7 @@ def __call__( decoder_input_tokens=input_tokens, decoder_positions=positions, decoder_segment_ids=decoder_segment_ids, + forced_routed_experts=forced_routed_experts, ) return logits, None diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 9fa42d6abe..07231f965e 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None: model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) - self.model = nnx.data(model) \ No newline at end of file + self.model = nnx.data(model) diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 56374f676e..92dbaab258 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -775,11 +775,15 @@ def __call__( kv_caches: list[jax.Array] | None = None, attention_metadata=None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + forced_routed_experts: jnp.ndarray | None = None, ): cfg = self.config mesh = self.mesh assert decoder_input_tokens.ndim == 2 # [batch, len] + if cfg.scan_layers and forced_routed_experts is not None: + raise NotImplementedError("Forced routing with scanned layers is not supported yet.") + # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( shared_embedding, @@ -1061,6 +1065,10 @@ def __call__( global_layer_idx = global_layer_idx_offset + index kv_cache = kv_caches[index] if kv_caches is not None else None input_tokens = decoder_input_tokens if cfg.engram_layers else None + current_forced_routed_experts = None + if forced_routed_experts is not None and layer_prefix == "moe_layers": + current_forced_routed_experts = forced_routed_experts[:, :, index, :] + y, kv_cache = layer( config=cfg, mesh=mesh, @@ -1080,11 +1088,13 @@ def __call__( kv_cache=kv_cache, attention_metadata=attention_metadata, decoder_input_tokens=input_tokens, + forced_routed_experts=current_forced_routed_experts, ) if kv_caches is not None and kv_cache is not None: kv_caches[index] = kv_cache global_layer_idx_offset += num_layers else: + moe_lyr_idx = 0 for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = RemattedBlockLayers[0] layer_kwargs = {} @@ -1121,19 +1131,38 @@ def __call__( layer = RemattedBlockLayer( config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs ) - y, returned_cache = layer( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - **layer_call_kwargs, - ) + current_forced_routed_experts = None + is_moe = False + if cfg.decoder_block in ( + DecoderBlockType.MIXTRAL, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.QWEN3_NEXT, + DecoderBlockType.QWEN3_5, + DecoderBlockType.QWEN3_CUSTOM_MOE, + ): + is_moe = True + elif cfg.decoder_block == DecoderBlockType.LLAMA4: + is_moe = llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step) + + if is_moe and forced_routed_experts is not None: + current_forced_routed_experts = forced_routed_experts[:, :, moe_lyr_idx, :] + moe_lyr_idx += 1 + elif is_moe: + moe_lyr_idx += 1 + + call_kwargs = { + "previous_chunk": previous_chunk, + "page_state": page_state, + "slot": slot, + "kv_cache": kv_cache, + "attention_metadata": attention_metadata, + } + call_kwargs.update(layer_call_kwargs) + + if is_moe and current_forced_routed_experts is not None: + call_kwargs["forced_routed_experts"] = current_forced_routed_experts + + y, returned_cache = layer(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, **call_kwargs) if kv_caches is not None and returned_cache is not None: if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): kv_caches[lyr] = returned_cache diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..9040fd9ff2 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -595,32 +595,47 @@ def should_update_load_balance(self): """Determines if loss-free load balancing updates should be applied.""" return self.config.routed_bias and self.config.routed_bias_update_rate > 0.0 - def get_topk(self, gate_logits, pre_bias_logits, rngs=None): + def get_topk(self, gate_logits, pre_bias_logits, rngs=None, forced_routed_experts=None): """get topk.""" # shape of top_k_weights & top_k_indices: # (batch, sequence, num_experts_per_tok). - if self.config.use_random_routing: - if rngs is None: - raise ValueError("The random key cannot be None for random routing.") - # Reuse the 'params' RNG stream to ensure random routing - rng = rngs.params() - top_k_weights, top_k_indices = random_routing(rng, gate_logits, self.num_experts_per_tok) - return top_k_weights, top_k_indices - - if self.config.model_name.startswith("deepseek3"): - top_k_weights, top_k_indices = self.deepseek_routing(gate_logits, pre_bias_logits) - elif self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4: - router_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1) - _, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) - top_k_weights = jnp.take_along_axis(router_probs, top_k_indices, axis=-1).astype(self.dtype) + if forced_routed_experts is not None: + top_k_indices = forced_routed_experts + if self.config.model_name.startswith("deepseek3"): + top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) + elif self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4: + router_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1) + top_k_weights = jnp.take_along_axis(router_probs, top_k_indices, axis=-1).astype(self.dtype) + else: + top_k_weights = jnp.take_along_axis(gate_logits, top_k_indices, axis=-1) else: - top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + if self.config.use_random_routing: + if rngs is None: + raise ValueError("The random key cannot be None for random routing.") + # Reuse the 'params' RNG stream to ensure random routing + rng = rngs.params() + top_k_weights, top_k_indices = random_routing(rng, gate_logits, self.num_experts_per_tok) + return top_k_weights, top_k_indices + + if self.config.model_name.startswith("deepseek3"): + top_k_weights, top_k_indices = self.deepseek_routing(gate_logits, pre_bias_logits) + elif self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4: + router_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1) + _, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + top_k_weights = jnp.take_along_axis(router_probs, top_k_indices, axis=-1).astype(self.dtype) + else: + top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK: top_k_weights = self.deepseek_scale_weights(top_k_weights) elif self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4): top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype) + # Zero out weights for padding indices! + if forced_routed_experts is not None: + valid_token_mask = top_k_indices[:, :, 0] != -1 + top_k_weights = top_k_weights * valid_token_mask[:, :, None] + # Normalization of router weights (e.g. used by Qwen3, Gemma4). if self.config.norm_topk_prob: top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True) @@ -715,13 +730,22 @@ def apply_ffn_activation(self, layer_w0, layer_w1): intermediate_layer = jnp.multiply(layer_act, layer_w1) return intermediate_layer.astype(self.dtype) - def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): + def permute( + self, + inputs, + gate_logits, + pre_bias_logits, + use_custom_sort_vjp=True, + rngs=None, + roll_to_expert_id=None, + forced_routed_experts=None, + ): """Permute tokens to group by expert to fit gmm call.""" # reshape inputs (batch, sequence, emb) to (batch * sequence, emb) inputs_shape = inputs.shape bsz_times_seq_len = inputs_shape[0] * inputs_shape[1] inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2])) - weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs) + weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs, forced_routed_experts) lb_loss = None if self.config.load_balance_loss_weight > 0.0: @@ -744,13 +768,22 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True flatten_selected_experts = jnp.ravel(selected_experts) if roll_to_expert_id is not None: flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % self.num_experts - sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + if forced_routed_experts is not None: + # Fix padding bug: map -1 to dummy valid indices to distribute load + valid_mask = flatten_selected_experts >= 0 + dummy_indices = jnp.arange(flatten_selected_experts.shape[0]) % self.num_experts + flatten_selected_experts_safe = jnp.where(valid_mask, flatten_selected_experts, dummy_indices) + else: + flatten_selected_experts_safe = flatten_selected_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts_safe) # sort inputs for number of selected experts replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0) sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype( self.dtype ) - group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts) + group_size = jnp.bincount(flatten_selected_experts_safe, length=self.num_experts) # Return the experts for each sorted input. expert_indices = jnp.arange(self.num_experts) sorted_experts = jnp.repeat( @@ -1033,6 +1066,7 @@ def sparse_matmul( w0_bias, w1_bias, wo_bias, + forced_routed_experts=None, ): """Perform sparse matrix multiplication of inputs and Experts.""" @@ -1269,6 +1303,9 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): w1_bias_pspec, wo_bias_pspec, P(), # Replicate the input key + self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + if forced_routed_experts is not None + else None, ), out_specs=( self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), @@ -1277,7 +1314,7 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): ), check_vma=False, ) - def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): + def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs, forced_routed_experts): batch_size, sequence_length, _ = x.shape num_expert_parallelism = self.get_expert_parallelism_size() if num_expert_parallelism > 1: @@ -1304,6 +1341,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r self.config.use_custom_sort_vjp, roll_to_expert_id=num_experts_per_shard * expert_shard_id, rngs=rngs, + forced_routed_experts=forced_routed_experts, ) # Filter down to the group sizes that apply to only the experts in the @@ -1313,7 +1351,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r x = jnp.where(mask[:, None], x, 0) else: x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( - x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs + x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs, forced_routed_experts=forced_routed_experts ) if num_expert_parallelism > 1: @@ -1579,6 +1617,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): inputs = self._maybe_shard_with_logical(inputs, input_axes) gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) + forced_routed_experts = ( + self._maybe_shard_with_logical(forced_routed_experts, (batch_logical_axis, "activation_norm_length", None)) + if forced_routed_experts is not None + else None + ) w0_kernel = self._maybe_shard_with_pspec(w0_kernel, w0_pspec) w1_kernel = self._maybe_shard_with_pspec(w1_kernel, w1_pspec) @@ -1591,25 +1634,43 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) return wrapper( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + self.rngs, + forced_routed_experts, ) - def reshape_and_update_weights(self, weights, indices): + def reshape_and_update_weights(self, weights, indices, safe_updates=False): """reshape and update weights.""" # input of weights and indices: (batch_size, seq_len, num_experts_per_tok) # output of updated weights: (batch_size, seq_len, num_experts) update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) + if safe_updates: + valid_mask = indices >= 0 + safe_indices = jnp.where(valid_mask, indices, 0) + safe_weights = jnp.where(valid_mask, weights, 0.0) + else: + safe_indices = indices + safe_weights = weights + index_update = ( self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch", None, None)), self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length", None)), - indices, + safe_indices, ) weight_sharding = ( create_sharding(self.mesh, ("activation_batch", "activation_length", None)) if self.config.shard_mode == ShardMode.EXPLICIT else None ) - update_weights = update_weights.at[index_update].set(weights, out_sharding=weight_sharding) + update_weights = update_weights.at[index_update].set(safe_weights, out_sharding=weight_sharding) return update_weights def get_context_partition_and_sub_seq(self, seq_len): @@ -1851,6 +1912,7 @@ def dense_matmul( w0_bias, w1_bias, wo_bias, + forced_routed_experts=None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert @@ -1860,13 +1922,19 @@ def dense_matmul( pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None) ) - top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) + if forced_routed_experts is not None: + forced_routed_experts = self._maybe_shard_with_logical( + forced_routed_experts, ("activation_batch_moe", "activation_length_moe", None) + ) + top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs, forced_routed_experts) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(self.dtype) inputs = inputs * router_scores else: - weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) + weights = self.reshape_and_update_weights( + top_k_weights, top_k_indices, safe_updates=(forced_routed_experts is not None) + ) matmul_precision = jax.lax.Precision(self.config.matmul_precision) # Calculate load balance loss @@ -2122,7 +2190,9 @@ def dense_matmul( intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo") with jax.named_scope("weight_sum"): if is_llama4_decoder_layer: - weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices) + weights = self.reshape_and_update_weights( + jnp.ones_like(top_k_weights), top_k_indices, safe_updates=(forced_routed_experts is not None) + ) if self.config.float32_weight_sum: intermediate_layer = intermediate_layer.astype(jnp.float32) weights = weights.astype(jnp.float32) @@ -2143,12 +2213,15 @@ def fused_moe_matmul( w0_kernel=None, w1_kernel=None, fused_kernel=None, + forced_routed_experts=None, ) -> tuple[jax.Array, None, None]: """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only). fused_moe_func handles routing, GMM, and weighted combination internally. It does not compute lb_loss or bias_updates (inference-only). """ + if forced_routed_experts is not None: + raise NotImplementedError("Forced routing via forced_routed_experts is not supported with fused_moe_matmul.") try: # pylint: disable=import-outside-toplevel # pytype: disable=import-error @@ -2231,7 +2304,11 @@ def retrieve_quantized_weight( return w0_kernel, w1_kernel, wo_kernel def __call__( - self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None + self, + inputs: jax.Array, + gate_inputs: jax.Array | None = None, + out_sharding: NamedSharding | None = None, + forced_routed_experts: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) @@ -2268,7 +2345,13 @@ def __call__( # vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference. if cfg.attention == "vllm_rpa": output, lb_loss, bias_updates = self.fused_moe_matmul( - inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel + inputs, + gate_logits, + wo_kernel, + w0_kernel=w0_kernel, + w1_kernel=w1_kernel, + fused_kernel=fused_kernel, + forced_routed_experts=forced_routed_experts, ) elif cfg.sparse_matmul: if quantizations.in_serve_mode(self.quant): @@ -2284,11 +2367,29 @@ def __call__( wo_bias, ) output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + forced_routed_experts=forced_routed_experts, ) else: output, lb_loss, bias_updates = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + forced_routed_experts=forced_routed_experts, ) return output, lb_loss, bias_updates @@ -2375,9 +2476,10 @@ def __call__( gate_inputs: jax.Array | None = None, intermediate_sharding: NamedSharding | None = None, out_sharding: NamedSharding | None = None, + forced_routed_experts: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: routed_experts, load_balance_loss, moe_bias_updates = self.routed_moe( - inputs, gate_inputs=gate_inputs, out_sharding=out_sharding + inputs, gate_inputs=gate_inputs, out_sharding=out_sharding, forced_routed_experts=forced_routed_experts ) shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding) return routed_experts + shared_experts, load_balance_loss, moe_bias_updates diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 606e81afd1..9d8b7a65bd 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1066,10 +1066,14 @@ def __call__( attention_metadata=None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, multimodal_input: None | MultimodalInput = None, + forced_routed_experts: jnp.ndarray | None = None, ): cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] + if cfg.scan_layers and forced_routed_experts is not None: + raise NotImplementedError("Forced routing with scanned layers is not supported yet.") + policy = self.get_remat_policy() # [batch, length] -> [batch, length, emb_dim] @@ -1099,6 +1103,9 @@ def __call__( if attention_metadata is not None: layer_kwargs["attention_metadata"] = attention_metadata + if forced_routed_experts is not None: + layer_kwargs["forced_routed_experts"] = forced_routed_experts + if cfg.scan_layers: if self.is_deepseek: layer_kwargs = { @@ -1205,17 +1212,17 @@ def __call__( prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) # Hoisted function to preserve XLA cache ID - def pure_layer_fn(graphdef, state_in, y_in, kv_in): + def pure_layer_fn(graphdef, state_in, y_in, kv_in, valid_kwargs): if cfg.parameter_memory_host_offload: state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in) merged_layer = nnx.merge(graphdef, state_in) - out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **valid_kwargs) return out_y, out_kv, nnx.state(merged_layer) checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - + moe_lyr_idx = 0 for lyr, layer in enumerate(self.layers): graphdef, state = nnx.split(layer) if kv_caches is not None: @@ -1233,7 +1240,33 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): if input_tokens is not None: layer_kwargs["decoder_input_tokens"] = input_tokens - y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) + # Slice forced_routed_experts for this specific layer index! + current_kwargs = dict(layer_kwargs) + + is_moe = False + if cfg.decoder_block in ( + DecoderBlockType.MIXTRAL, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.QWEN3_NEXT, + DecoderBlockType.QWEN3_5, + DecoderBlockType.QWEN3_CUSTOM_MOE, + ): + is_moe = True + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK: + is_moe = lyr >= cfg.first_num_dense_layers + elif cfg.decoder_block == DecoderBlockType.LLAMA4: + is_moe = llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step) + + if is_moe and "forced_routed_experts" in current_kwargs and current_kwargs["forced_routed_experts"] is not None: + routed_experts = current_kwargs["forced_routed_experts"] + current_kwargs["forced_routed_experts"] = routed_experts[:, :, moe_lyr_idx, :] + moe_lyr_idx += 1 + else: + current_kwargs.pop("forced_routed_experts", None) + if is_moe: + moe_lyr_idx += 1 # Still increment counter if it is an MoE layer but data is missing! + + y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache, current_kwargs) nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 7bb532ae7f..56080fe23f 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -34,6 +34,7 @@ import jax from jax import tree_util as jtu import qwix +import inspect M = tp.TypeVar("M", bound=Module) @@ -417,8 +418,27 @@ class ToLinen(linen.Module): # generic function to augment original nnx module (i.e for learn-to-init distillation) nnx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None + def __post_init__(self): + super().__post_init__() + + clashing_params = {} + _call_fn = getattr(self.nnx_class, "__call__", None) + if _call_fn and callable(_call_fn): + sig = inspect.signature(_call_fn) + params = list(sig.parameters.keys()) + for name in ("previous_chunk", "page_state", "slot"): + if name in params: + clashing_params[name] = params.index(name) - 1 + object.__setattr__(self, "_clashing_params", clashing_params) + @linen.compact def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs): + # Pop pre-partialled keyword arguments from kwargs if they are also passed positionally in args + # to avoid Python's multiple values clashing (e.g., in Linen scanned loops). + for param_name, positional_idx in getattr(self, "_clashing_params", {}).items(): + if len(args) > positional_idx: + kwargs.pop(param_name, None) + def _module_kwargs(): maybe_add_default = not self.is_initializing() module_kwargs = dict(self.kwargs) diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 0980b78599..dd10dbb081 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -356,6 +356,7 @@ def __call__( kv_cache=None, attention_metadata=None, decoder_input_tokens=None, + forced_routed_experts=None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -440,6 +441,7 @@ def __call__( kv_cache=None, attention_metadata=None, decoder_input_tokens=None, + forced_routed_experts=None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -602,15 +604,20 @@ def extract_fn(x): load_balance_loss = metadata["load_balance_loss"] moe_bias_updates = metadata["moe_bias_updates"] else: - mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic) + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op( + hidden_states, deterministic, forced_routed_experts=forced_routed_experts + ) layer_output = mlp_lnx + intermediate_inputs layer_output = self.dropout_op(layer_output, deterministic=deterministic) return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) - def mlp_op(self, x, deterministic, *args, **kwargs): + def mlp_op(self, x, deterministic, forced_routed_experts=None): mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0( - x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding + x, + intermediate_sharding=self.mlp_intermediate_sharding, + out_sharding=self.out_sharding, + forced_routed_experts=forced_routed_experts, ) return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates diff --git a/src/maxtext/models/gemma4.py b/src/maxtext/models/gemma4.py index 015758d50f..144cb179f0 100644 --- a/src/maxtext/models/gemma4.py +++ b/src/maxtext/models/gemma4.py @@ -119,6 +119,7 @@ def __call__( original_inputs: jax.Array | None = None, intermediate_sharding: jax.sharding.NamedSharding | None = None, out_sharding: jax.sharding.NamedSharding | None = None, + forced_routed_experts: jax.Array | None = None, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: shared_experts = self.moe_block.shared_experts( inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding @@ -138,7 +139,7 @@ def __call__( # 3. Pass both to routed_moe routed_experts, load_balance_loss, moe_bias_updates = self.moe_block.routed_moe( - routed_inputs, gate_inputs=gate_inputs, out_sharding=out_sharding + routed_inputs, gate_inputs=gate_inputs, out_sharding=out_sharding, forced_routed_experts=forced_routed_experts ) routed_experts = self.post_feedforward_layernorm_2(routed_experts) @@ -319,6 +320,7 @@ def __call__( bidirectional_mask=None, kv_cache=None, attention_metadata=None, + forced_routed_experts: jnp.ndarray | None = None, ): cfg = self.config # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) @@ -363,7 +365,9 @@ def __call__( # MLP block. if getattr(self.config, "num_experts", 1) > 1: - mlp_lnx, load_balance_loss, _ = self.mlp(attn_output, original_inputs=attention_lnx) + mlp_lnx, load_balance_loss, _ = self.mlp( + attn_output, original_inputs=attention_lnx, forced_routed_experts=forced_routed_experts + ) if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) else: diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index faf69273c6..7bc06f3109 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -135,6 +135,7 @@ def __call__( slot=None, kv_cache=None, attention_metadata=None, + forced_routed_experts: jnp.ndarray | None = None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -168,7 +169,7 @@ def __call__( # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. # The `name` represents the weight name in JAX/checkpoints and so the class name # is just for readability. - mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states) + mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states, forced_routed_experts=forced_routed_experts) mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) layer_output = mlp_lnx + intermediate_inputs diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..2d31a304e6 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -140,6 +140,7 @@ def __call__( nnx_method=None, kv_caches: list[jax.Array] | None = None, attention_metadata: dict[str, Any] | None = None, + forced_routed_experts: jnp.ndarray | None = None, ): """Applies Transformer decoder-branch on encoded-input and target. @@ -199,6 +200,7 @@ def __call__( kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, + forced_routed_experts=forced_routed_experts, ) # pytype: disable=wrong-keyword-args # If we are initializing the model AND MTP is enabled, we must create @@ -430,6 +432,7 @@ def __call__( decoder_target_mask: jax.Array | None = None, kv_caches: list[jax.Array] | None = None, attention_metadata: dict[str, Any] | None = None, + forced_routed_experts: jnp.ndarray | None = None, ): """Applies the Zero-1 FSDP wrapped Transformer model. @@ -456,6 +459,7 @@ def __call__( Returns: Logits from the Transformer model. Logits, hidden_state, kv_caches if called by vLLM. """ + if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE: raise ValueError( f"During autoregressive decoding we assume the tokens are in the active sequence" @@ -513,6 +517,7 @@ def __call__( kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, + forced_routed_experts=forced_routed_experts, ) # pytype: disable=wrong-keyword-args else: logits, hidden_state, kv_caches = self.decoder( @@ -529,6 +534,7 @@ def __call__( kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, + forced_routed_experts=forced_routed_experts, mutable=mutable_collections, ) # pytype: disable=wrong-keyword-args diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..4a4bc7cc5c 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -821,7 +821,9 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn rngs=rngs, ) - def __call__(self, hidden_states: Array, deterministic: bool) -> tuple[Array, Array | None]: + def __call__( + self, hidden_states: Array, deterministic: bool, forced_routed_experts: jnp.ndarray | None = None + ) -> tuple[Array, Array | None]: """ Applies the sparse MoE block to the input hidden states. @@ -835,7 +837,7 @@ def __call__(self, hidden_states: Array, deterministic: bool) -> tuple[Array, Ar - The load balancing loss from the routed experts, if applicable during training. """ # 1. Apply the routed experts block. - routed_output, load_balance_loss, _ = self.routed_experts(hidden_states) + routed_output, load_balance_loss, _ = self.routed_experts(hidden_states, forced_routed_experts=forced_routed_experts) # 2. Apply the shared expert. shared_expert_output = self.shared_expert(hidden_states, deterministic=deterministic) @@ -1012,6 +1014,7 @@ def __call__( slot: None | int = None, kv_cache: None | dict[str, Array] = None, attention_metadata: None | dict[str, Any] = None, + forced_routed_experts: jnp.ndarray | None = None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -1054,7 +1057,9 @@ def __call__( hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) # Instantiate and call our `Qwen3NextSparseMoeBlock`. - mlp_output, load_balance_loss = self.mlp(hidden_states, deterministic=deterministic) + mlp_output, load_balance_loss = self.mlp( + hidden_states, deterministic=deterministic, forced_routed_experts=forced_routed_experts + ) # We sow the load balancing loss so it can be collected and added to the total loss # during training. @@ -1283,6 +1288,7 @@ def __call__( slot: None | int = None, kv_cache: None | jnp.ndarray = None, attention_metadata: None | dict[str, Any] = None, + forced_routed_experts: jnp.ndarray | None = None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) is_scan_carry = False @@ -1305,7 +1311,7 @@ def __call__( attention_metadata=attention_metadata, ) - mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states) + mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states, forced_routed_experts=forced_routed_experts) mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.moe_lb_loss = nnx.Intermediate(load_balance_loss) diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..309257d534 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -185,6 +185,7 @@ def __call__( slot: None | int = None, kv_cache: None | dict[str, Array] = None, attention_metadata: None | dict[str, Any] = None, + forced_routed_experts: jnp.ndarray | None = None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -227,7 +228,9 @@ def __call__( hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) # Instantiate and call our `Qwen3_5SparseMoEBlock`. - mlp_output, load_balance_loss = self.mlp(hidden_states, deterministic=deterministic) + mlp_output, load_balance_loss = self.mlp( + hidden_states, deterministic=deterministic, forced_routed_experts=forced_routed_experts + ) # We sow the load balancing loss so it can be collected and added to the total loss # during training. diff --git a/tests/unit/forced_routing_test.py b/tests/unit/forced_routing_test.py new file mode 100644 index 0000000000..343ae885fb --- /dev/null +++ b/tests/unit/forced_routing_test.py @@ -0,0 +1,109 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for forced routing in moe.py.""" + +import unittest +import jax +import jax.numpy as jnp +from maxtext.layers import moe +from maxtext.common import common_types as ctypes + + +class DummyConfig: + + def __init__(self, model_name="default", decoder_block=ctypes.DecoderBlockType.DEFAULT): + self.model_name = model_name + self.decoder_block = decoder_block + self.norm_topk_prob = False + self.use_random_routing = False + self.shard_mode = ctypes.ShardMode.AUTO + + +class DummyRoutedMoE: + + def __init__(self, config): + self.config = config + self.dtype = jnp.float32 + self.num_experts_per_tok = 2 + self.num_experts = 3 + + def _maybe_shard_with_logical(self, x, spec): + return x + + +class ForcedRoutingTest(unittest.TestCase): + + def test_basic_override(self): + config = DummyConfig() + model = DummyRoutedMoE(config) + + gate_logits = jnp.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) # (1, 2, 3) + pre_bias_logits = gate_logits # Not DeepSeek + forced_routed_experts = jnp.array([[[2, 1], [0, 2]]]) # (1, 2, 2) + + top_k_weights, top_k_indices = moe.RoutedMoE.get_topk( + model, gate_logits, pre_bias_logits, forced_routed_experts=forced_routed_experts + ) + + # Check that indices are overridden + self.assertTrue((top_k_indices == forced_routed_experts).all()) + # Check that weights are extracted correctly and softmaxed + # For token 0: indices 2, 1 -> logits 3.0, 2.0 -> softmax([3.0, 2.0]) + # For token 1: indices 0, 2 -> logits 4.0, 6.0 -> softmax([4.0, 6.0]) + expected_weights = jax.nn.softmax(jnp.array([[[3.0, 2.0], [4.0, 6.0]]]).astype(jnp.float32), axis=-1) + self.assertTrue(jax.numpy.allclose(top_k_weights, expected_weights, rtol=1e-5, atol=1e-5)) + + def test_gemma4_softmax(self): + config = DummyConfig(decoder_block=ctypes.DecoderBlockType.GEMMA4) + model = DummyRoutedMoE(config) + + gate_logits = jnp.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) # (1, 2, 3) + pre_bias_logits = gate_logits + forced_routed_experts = jnp.array([[[2, 1], [0, 2]]]) # (1, 2, 2) + + top_k_weights, top_k_indices = moe.RoutedMoE.get_topk( + model, gate_logits, pre_bias_logits, forced_routed_experts=forced_routed_experts + ) + + # Check that indices are overridden + self.assertTrue((top_k_indices == forced_routed_experts).all()) + + # For Gemma 4, it applies softmax to gate_logits first! + + expected_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1) + expected_weights = jnp.take_along_axis(expected_probs, forced_routed_experts, axis=-1) + + self.assertTrue(jax.numpy.allclose(top_k_weights, expected_weights, rtol=1e-5, atol=1e-5)) + + def test_reshape_and_update_weights(self): + config = DummyConfig() + model = DummyRoutedMoE(config) + + weights = jnp.array([[[0.1, 0.2], [0.3, 0.4]]]) # (1, 2, 2) + indices = jnp.array([[[2, -1], [-1, 1]]]) # (1, 2, 2) + + update_weights = moe.RoutedMoE.reshape_and_update_weights(model, weights, indices, safe_updates=True) + + # Expected shape: (1, 2, 3) where 3 is num_experts! + # For token 0: index 2 -> 0.1. Index -1 -> mapped to 0 but weight 0.0! + # So for expert 0: 0.0. Expert 1: 0.0. Expert 2: 0.1. + # For token 1: index -1 -> mapped to 0 but weight 0.0! Index 1 -> 0.4. + # So for expert 0: 0.0. Expert 1: 0.4. Expert 2: 0.0. + expected_update_weights = jnp.array([[[0.0, 0.0, 0.1], [0.0, 0.4, 0.0]]]) + + self.assertTrue(jax.numpy.allclose(update_weights, expected_update_weights, rtol=1e-5, atol=1e-5)) + + +if __name__ == "__main__": + unittest.main()