diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 77751479ce..6a6106149a 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -435,7 +435,9 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], @@ -448,6 +450,8 @@ logical_axis_rules: [ ['activation_attn_length_no_exp', ['context']], ['activation_length_no_exp', ['sequence', 'context']], ['activation_length_no_exp', ['context']], + ['activation_length_no_exp_moe', ['sequence', 'context']], + ['activation_length_no_exp_moe', ['context']], ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], ['activation_q_length', ['context', 'expert']], ['activation_q_length_no_exp', ['context']], @@ -456,6 +460,7 @@ logical_axis_rules: [ ['activation_kv_length', []], ['activation_attn_embed', ['tensor', 'tensor_transpose']], ['activation_embed', ['tensor', 'tensor_transpose']], + ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], @@ -484,6 +489,10 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'context']], ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 62f2dbe370..aece47bf63 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -38,6 +38,7 @@ logical_axis_rules: [ ['activation_q_length', ['expert']], ['activation_attn_embed', ['tensor']], ['activation_embed', ['tensor']], + ['activation_embed_moe', ['tensor']], ['activation_mlp', ['tensor']], ['activation_kv', ['tensor']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']], @@ -56,6 +57,7 @@ logical_axis_rules: [ ['kv_heads', ['tensor']], ['embed', ['fsdp', 'expert']], ['embed_no_exp', ['fsdp']], + ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['norm', ['tensor']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index db9aafce8b..6f2bc0f9de 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -27,6 +27,7 @@ logical_axis_rules: [ ['decode_batch', ['fsdp']], ['embed', ['fsdp']], ['embed_no_exp', ['fsdp']], + ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['exp_with_fsdp', 'fsdp'], diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e7f548c847..9054837e1f 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. contract_ind = tuple(range(0, len(norm_axis))) output_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None)) + create_sharding(self.mesh, ("activaton_batch_no_exp_moe", "activaton_length_no_exp_moe", None)) if self.shard_mode == ShardMode.EXPLICIT else None ) @@ -351,16 +351,16 @@ def __init__( if self.config.shard_exp_on_fsdp: # special sharding for dsv3 - self.wi_kernel_axes = ("embed_no_exp", None, "mlp") - self.wo_kernel_axes = ("embed_no_exp", "mlp", None) + self.wi_kernel_axes = ("embed_moe", None, "mlp") + self.wo_kernel_axes = ("embed_moe", "mlp", None) elif self.config.use_2d_fsdp_sharding: - self.wi_kernel_axes = ("embed_no_exp", "mlp", None) - self.wo_kernel_axes = ("embed_no_exp", "mlp", None) + self.wi_kernel_axes = ("embed_moe", "mlp", None) + self.wo_kernel_axes = ("embed_moe", "mlp", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: - self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp") - self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp") + self.wi_kernel_axes = ("exp", "embed_moe", "mlp") + self.wo_kernel_axes = ("exp", "mlp", "embed_moe") if self.config.attention == "vllm_rpa": # vLLM uses 'model' as the tensor parallelism axis name @@ -437,7 +437,7 @@ def __init__( if self.config.mlp_bias: wi_bias_axes = ("exp", "activation_mlp") - wo_bias_axes = ("exp", "activation_embed") + wo_bias_axes = ("exp", "activation_embed_moe") wi_bias_shape = (self.num_experts, self.intermediate_dim) wo_bias_shape = (self.num_experts, self.config.emb_dim) self.wi_0_bias = nnx.Param( @@ -1018,7 +1018,7 @@ def gmm( self._expert_parallelism_name in tuple( filter( - lambda tup: tup[0] == "activation_batch", + lambda tup: tup[0] == "activaton_batch_moe", self.config.logical_axis_rules, ) )[ @@ -1028,26 +1028,26 @@ def gmm( except: # pylint: disable=bare-except is_batch_sharded_by_expert = False if is_batch_sharded_by_expert and inputs.shape[0] > 1: - batch_logical_axis = "activation_batch" + batch_logical_axis = "activaton_batch_moe" else: - batch_logical_axis = "activation_batch_no_exp" + batch_logical_axis = "activaton_batch_no_exp_moe" if self.get_tensor_transpose_parallelism_size() > 1: input_partition_pspec = self._logical_to_mesh_axes( - (batch_logical_axis, "activation_norm_length", "activation_embed") + (batch_logical_axis, "activaton_length_no_exp_moe", "activation_embed_moe") ) w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activaton_length_no_exp_moe", None)) w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) - gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activaton_length_no_exp_moe", None)) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activaton_length_no_exp_moe", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits_pspec = None @@ -1099,7 +1099,7 @@ def gmm( P(), # Replicate the input key ), out_specs=( - self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), + self._logical_to_mesh_axes((batch_logical_axis, "activaton_length_no_exp_moe", "activation_embed_moe")), P(), # Handle None or replicate the output P(), # Handle None or replicate the output ), @@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose")) if self.get_tensor_transpose_parallelism_size() > 1: - input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed") + input_axes = (batch_logical_axis, "activaton_length_no_exp_moe", "activation_embed_moe") else: - input_axes = (batch_logical_axis, "activation_norm_length", None) + input_axes = (batch_logical_axis, "activaton_length_no_exp_moe", None) - gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) + gate_logits_axes = (batch_logical_axis, "activaton_length_no_exp_moe", None) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) + pre_bias_logits_axes = (batch_logical_axis, "activaton_length_no_exp_moe", None) else: pre_bias_logits_axes = None @@ -1436,13 +1436,13 @@ def reshape_and_update_weights(self, weights, indices): update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( self._maybe_shard_with_logical( - jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None) + jnp.arange(weights.shape[0])[:, None, None], ("activaton_batch_no_exp_moe", None, None) ), - self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activaton_length_no_exp_moe", None)), indices, ) weight_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None)) + create_sharding(self.mesh, ("activaton_batch_no_exp_moe", "activaton_length_no_exp_moe", None)) if self.config.shard_mode == ShardMode.EXPLICIT else None ) @@ -1497,7 +1497,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activaton_batch_moe", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, @@ -1505,7 +1505,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch", "activation_norm_length", None, None, None), + ("activaton_batch_moe", "activaton_length_no_exp_moe", None, None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) @@ -1585,7 +1585,7 @@ def generate_masks(self, top_k_indices, softmax_probs): expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activaton_batch_moe", None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) expert_token_count = jnp.reshape( expert_token_count_fused, @@ -1593,7 +1593,7 @@ def generate_masks(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch", "activation_norm_length", None, None), + ("activaton_batch_moe", "activaton_length_no_exp_moe", None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) @@ -1691,11 +1691,13 @@ def dense_matmul( ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None)) + gate_logits = self._maybe_shard_with_logical( + gate_logits, ("activaton_batch_moe", "activaton_length_no_exp_moe", None) + ) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits = self._maybe_shard_with_logical( - pre_bias_logits, ("activation_batch", "activation_norm_length", None) + pre_bias_logits, ("activaton_batch_moe", "activaton_length_no_exp_moe", None) ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 @@ -1735,16 +1737,16 @@ def dense_matmul( dispatch_mask, combine_mask = self.generate_masks( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) - mask_axes = ("activation_batch", "activation_norm_length", None, None) + mask_axes = ("activaton_batch_moe", "activaton_length_no_exp_moe", None, None) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, "activation_mlp", ) @@ -1759,56 +1761,56 @@ def dense_matmul( dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs) if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1: mask_axes = ( - "activation_norm_length", - "activation_batch", + "activaton_length_no_exp_moe", + "activaton_batch_moe", None, None, None, ) input_axis = ( - "activation_norm_length", - "activation_batch", + "activaton_length_no_exp_moe", + "activaton_batch_moe", None, - "activation_embed", + "activation_embed_moe", ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, None, "activation_mlp", ) else: mask_axes = ( - "activation_batch", - "activation_norm_length", + "activaton_batch_moe", + "activaton_length_no_exp_moe", None, None, None, ) input_axis = ( - "activation_batch", - "activation_norm_length", + "activaton_batch_moe", + "activaton_length_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, None, - "activation_embed", + "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, None, "activation_mlp", @@ -1834,10 +1836,10 @@ def dense_matmul( dispatch, ( None, - "activation_batch_no_exp", - "activation_norm_length", + "activaton_batch_no_exp_moe", + "activaton_length_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ), ) dispatch = self._maybe_shard_with_logical( @@ -1897,9 +1899,9 @@ def dense_matmul( intermediate_layer, ( "activation_exp", - "activation_batch_no_exp", + "activaton_batch_no_exp_moe", None, - "activation_embed", + "activation_embed_moe", ), ) intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") @@ -1922,7 +1924,9 @@ def dense_matmul( ) return output, lb_loss, bias_updates else: - inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = self._maybe_shard_with_logical( + inputs, ("activaton_batch_moe", "activaton_length_no_exp_moe", "activation_embed_moe") + ) with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision @@ -2082,7 +2086,7 @@ def __init__( num_experts_per_tok=self.config.num_experts_per_tok, mesh=self.mesh, kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), + kernel_axes=("embed_moe", None), intermediate_dim=self.config.moe_mlp_dim, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype,