@@ -351,16 +351,16 @@ def __init__(
 
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", None, "mlp")
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_2d_fsdp_sharding:
-      self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", "mlp", None)
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+      self.wi_kernel_axes = ("exp", "embed_moe", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_moe")
 
     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
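
The hunk above only renames logical axis labels; whether sharding actually changes depends on how the logical-axis rules map "embed_moe" versus "embed_no_exp" to mesh axes. A minimal sketch of that resolution step, using Flax's helper with made-up rules and mesh axis names (not MaxText's real configuration):

```python
# Minimal sketch: how a logical kernel-axis tuple such as
# ("exp", "embed_moe", "mlp") becomes a mesh PartitionSpec.
# The rules below are illustrative assumptions, not MaxText's config.
from flax.linen import partitioning as nn_partitioning

rules = (
    ("exp", "expert"),      # expert dimension -> 'expert' mesh axis
    ("embed_moe", "fsdp"),  # MoE embed dimension -> 'fsdp' mesh axis
    ("mlp", "tensor"),      # hidden dimension -> 'tensor' mesh axis
)

wi_kernel_axes = ("exp", "embed_moe", "mlp")
pspec = nn_partitioning.logical_to_mesh_axes(wi_kernel_axes, rules)
print(pspec)  # PartitionSpec('expert', 'fsdp', 'tensor')
```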
@@ -437,7 +437,7 @@ def __init__(
 
     if self.config.mlp_bias:
       wi_bias_axes = ("exp", "activation_mlp")
-      wo_bias_axes = ("exp", "activation_embed")
+      wo_bias_axes = ("exp", "activation_embed_moe")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
       wo_bias_shape = (self.num_experts, self.config.emb_dim)
       self.wi_0_bias = nnx.Param(
@@ -1034,20 +1034,20 @@ def gmm(
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length", "activation_embed")
+          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
       )
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
 
-    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -1099,7 +1099,7 @@ def gmm(
             P(),  # Replicate the input key
         ),
         out_specs=(
-            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")),
             P(),  # Handle None or replicate the output
             P(),  # Handle None or replicate the output
         ),
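
These in_specs/out_specs feed what appears to be a shard_map-style wrapper around the grouped matmul, with P() marking replicated operands. A toy sketch of that pattern under assumed mesh and function names (not the actual gmm kernel):

```python
# Toy shard_map sketch mirroring the in_specs/out_specs structure above.
# The mesh layout and local_fn are assumptions for illustration only.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

def local_fn(x, key):
  # Runs once per shard; x is the local block of the sharded operand.
  return x * 2.0, key

sharded_fn = shard_map(
    local_fn,
    mesh=mesh,
    in_specs=(P("data"), P()),   # shard x on 'data', replicate key
    out_specs=(P("data"), P()),  # shard the output, replicate the key
)
# Assumes the 'data' axis size divides the leading dimension (8 here).
y, k = sharded_fn(jnp.ones((8, 4)), jnp.zeros((2,)))
```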
@@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
     else:
-      input_axes = (batch_logical_axis, "activation_norm_length", None)
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", None)
 
-    gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+    gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     else:
       pre_bias_logits_axes = None
 
@@ -1505,7 +1505,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None, None),
+        ("activation_batch", "activation_norm_length_moe", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1593,7 +1593,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None),
+        ("activation_batch", "activation_norm_length_moe", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
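
The two hunks above are the capacity-truncation step: an assignment survives only while its expert's running token count is at or under capacity. A tiny worked example of the same arithmetic with toy shapes (all names local to the example):

```python
# Capacity truncation on toy shapes: 1 batch, 4 tokens, 2 experts,
# capacity 2 per expert. Mirrors the trunc_expert_mask lines above.
import jax.numpy as jnp

# expert_mask[b, s, e] = 1 if token s was routed to expert e.
expert_mask = jnp.array([[[1, 0], [1, 0], [1, 0], [0, 1]]])
# Running per-expert count of assigned tokens along the sequence.
expert_token_count = jnp.cumsum(expert_mask, axis=1)
expert_capacity_per_batch = 2

trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
print(trunc_expert_mask[0, :, 0])
# [1 1 0 0]: the third token routed to expert 0 is over capacity and dropped.
```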
@@ -1691,11 +1691,11 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits = self._maybe_shard_with_logical(
-          pre_bias_logits, ("activation_batch", "activation_norm_length", None)
+          pre_bias_logits, ("activation_batch", "activation_norm_length_moe", None)
       )
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
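
_maybe_shard_with_logical is presumably a thin guard around a logical sharding constraint. A hedged stand-in using Flax's public helper, as an assumption about the shape of the helper rather than the class's actual implementation:

```python
# Hypothetical stand-in for self._maybe_shard_with_logical.
# Assumes Flax's logical-constraint helper and an active mesh plus
# logical-axis rules in scope; MaxText's real helper may differ.
import flax.linen as nn

def maybe_shard_with_logical(x, logical_axes):
  # Constrain activation x to the given logical axes; the names map to
  # mesh axes through whatever logical-axis rules are currently in effect.
  return nn.with_logical_constraint(x, logical_axes)
```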
@@ -1735,12 +1735,12 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks(
           top_k_indices, weights  # pylint: disable=undefined-variable,possibly-used-before-assignment
       )
-      mask_axes = ("activation_batch", "activation_norm_length", None, None)
+      mask_axes = ("activation_batch", "activation_norm_length_moe", None, None)
       dispatch_axis = (
           "activation_exp",
           "activation_batch_no_exp",
           None,
-          "activation_embed",
+          "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
@@ -1759,24 +1759,24 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
       if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
         mask_axes = (
-            "activation_norm_length",
+            "activation_norm_length_moe",
             "activation_batch",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_norm_length",
+            "activation_norm_length_moe",
             "activation_batch",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
             "activation_batch_no_exp",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
@@ -1788,23 +1788,23 @@ def dense_matmul(
       else:
         mask_axes = (
             "activation_batch",
-            "activation_norm_length",
+            "activation_norm_length_moe",
             None,
             None,
             None,
         )
         input_axis = (
             "activation_batch",
-            "activation_norm_length",
+            "activation_norm_length_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
             "activation_batch_no_exp",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
@@ -1835,9 +1835,9 @@ def dense_matmul(
           (
               None,
               "activation_batch_no_exp",
-              "activation_norm_length",
+              "activation_norm_length_moe",
               None,
-              "activation_embed",
+              "activation_embed_moe",
           ),
       )
       dispatch = self._maybe_shard_with_logical(
@@ -1899,7 +1899,7 @@ def dense_matmul(
               "activation_exp",
               "activation_batch_no_exp",
               None,
-              "activation_embed",
+              "activation_embed_moe",
           ),
       )
       intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
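
adc.checkpoint_name tags the activation so a rematerialization policy can refer to it by name. A minimal sketch of the underlying JAX API, assuming adc aliases jax.ad_checkpoint (as the call pattern suggests):

```python
# Sketch: name an intermediate, then remat everything except named values.
import jax
import jax.numpy as jnp
from jax import ad_checkpoint as adc

def f(x):
  y = jnp.sin(x)
  # Tag this intermediate so a remat policy can save it by name.
  y = adc.checkpoint_name(y, "mlpwo")
  return jnp.sum(y * y)

policy = jax.checkpoint_policies.save_only_these_names("mlpwo")
grad = jax.grad(jax.checkpoint(f, policy=policy))(jnp.ones(4))
```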
@@ -1922,7 +1922,9 @@ def dense_matmul(
       )
       return output, lb_loss, bias_updates
     else:
-      inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = self._maybe_shard_with_logical(
+          inputs, ("activation_batch", "activation_norm_length_moe", "activation_embed_moe")
+      )
       with jax.named_scope("wi_0"):
         layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
             "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
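
The "BSM,EMH -> BSEH" contraction is the dense (no-dispatch) path: every token is multiplied by every expert's wi kernel. A quick shape check with toy dimensions:

```python
# Shape check for the dense einsum above.
import jax.numpy as jnp

B, S, M, E, H = 2, 3, 4, 8, 5  # batch, seq, emb_dim, experts, mlp hidden
inputs = jnp.ones((B, S, M))
w0_kernel = jnp.ones((E, M, H))
layer_w0 = jnp.einsum("BSM,EMH -> BSEH", inputs, w0_kernel)
print(layer_w0.shape)  # (2, 3, 8, 5): one activation per token per expert
```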
@@ -2082,7 +2084,7 @@ def __init__(
         num_experts_per_tok=self.config.num_experts_per_tok,
         mesh=self.mesh,
         kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
-        kernel_axes=("embed", None),
+        kernel_axes=("embed_moe", None),
         intermediate_dim=self.config.moe_mlp_dim,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
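
Taken together, this is a mechanical rename of the MoE logical axes; the full old-to-new mapping, derived from the hunks above:

```python
# Logical-axis renames applied by this diff.
AXIS_RENAMES = {
    "embed": "embed_moe",  # gate kernel_axes only
    "embed_no_exp": "embed_moe",
    "activation_embed": "activation_embed_moe",
    "activation_norm_length": "activation_norm_length_moe",
}
```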