
Commit 7331bbf

split logical names in moe module
1 parent 785ac61 commit 7331bbf

3 files changed: 47 additions & 37 deletions

The new *_moe logical names are introduced with the same mesh-axis rules as the existing names they are split from, so this commit does not change sharding behavior by itself; the split presumably makes it possible to later tune the sharding of MoE activations and weights independently of the dense paths.

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
@@ -449,13 +449,15 @@ logical_axis_rules: [
   ['activation_length_no_exp', ['sequence', 'context']],
   ['activation_length_no_exp', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
   ['activation_q_length', ['context', 'expert']],
   ['activation_q_length_no_exp', ['context']],
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -484,6 +486,10 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
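
For context on how these rules are consumed: each logical_axis_rules entry maps a logical axis name to the mesh axes that dimension is sharded over, and duplicate entries for the same name act as fallbacks (in Flax's helper, the first rule whose mesh axes are not already taken by another dimension wins). The snippet below is a minimal sketch of that translation using Flax's public logical_to_mesh_axes with a hand-picked subset of the rules above; it is illustrative only, not MaxText's actual resolution code.

# Minimal sketch (assumes Flax; rule subset chosen for illustration only).
import flax.linen as nn

rules = (
    ("activation_norm_length_moe", ("tensor_sequence", "context", "sequence")),
    ("activation_embed_moe", ("tensor", "tensor_transpose")),
    ("embed_moe", ("fsdp", "sequence", "context")),
)

# Logical shape of an MoE activation: (batch, length, embed).
logical_axes = ("activation_batch", "activation_norm_length_moe", "activation_embed_moe")

pspec = nn.logical_to_mesh_axes(logical_axes, rules)
# Roughly PartitionSpec(None, ('tensor_sequence', 'context', 'sequence'),
# ('tensor', 'tensor_transpose')): 'activation_batch' has no rule in this
# subset, so that dimension is left unsharded.
print(pspec)

Because the new *_moe rules carry the same mesh axes as the names they were split from, the resolved PartitionSpecs are identical today; editing only the *_moe rows would retune MoE sharding without touching the dense layers.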

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ logical_axis_rules: [
   ['activation_q_length', ['expert']],
   ['activation_attn_embed', ['tensor']],
   ['activation_embed', ['tensor']],
+  ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
@@ -56,6 +57,7 @@ logical_axis_rules: [
   ['kv_heads', ['tensor']],
   ['embed', ['fsdp', 'expert']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['norm', ['tensor']],

src/maxtext/layers/moe.py

Lines changed: 39 additions & 37 deletions
@@ -351,16 +351,16 @@ def __init__(
 
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", None, "mlp")
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_2d_fsdp_sharding:
-      self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", "mlp", None)
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+      self.wi_kernel_axes = ("exp", "embed_moe", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_moe")
 
     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
@@ -437,7 +437,7 @@ def __init__(
 
     if self.config.mlp_bias:
       wi_bias_axes = ("exp", "activation_mlp")
-      wo_bias_axes = ("exp", "activation_embed")
+      wo_bias_axes = ("exp", "activation_embed_moe")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
       wo_bias_shape = (self.num_experts, self.config.emb_dim)
       self.wi_0_bias = nnx.Param(
@@ -1034,20 +1034,20 @@ def gmm(
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length", "activation_embed")
+          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
       )
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
 
-    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -1099,7 +1099,7 @@ def gmm(
             P(),  # Replicate the input key
         ),
         out_specs=(
-            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")),
            P(),  # Handle None or replicate the output
            P(),  # Handle None or replicate the output
        ),
@@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
     else:
-      input_axes = (batch_logical_axis, "activation_norm_length", None)
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", None)
 
-    gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+    gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     else:
       pre_bias_logits_axes = None
 
@@ -1505,7 +1505,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None, None),
+        ("activation_batch", "activation_norm_length_moe", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1593,7 +1593,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None),
+        ("activation_batch", "activation_norm_length_moe", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1691,11 +1691,11 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits = self._maybe_shard_with_logical(
-          pre_bias_logits, ("activation_batch", "activation_norm_length", None)
+          pre_bias_logits, ("activation_batch", "activation_norm_length_moe", None)
       )
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
@@ -1735,12 +1735,12 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks(
          top_k_indices, weights  # pylint: disable=undefined-variable,possibly-used-before-assignment
      )
-      mask_axes = ("activation_batch", "activation_norm_length", None, None)
+      mask_axes = ("activation_batch", "activation_norm_length_moe", None, None)
      dispatch_axis = (
          "activation_exp",
          "activation_batch_no_exp",
          None,
-          "activation_embed",
+          "activation_embed_moe",
      )
      mlp_axis = (
          "activation_exp",
@@ -1759,24 +1759,24 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
       if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
         mask_axes = (
-            "activation_norm_length",
+            "activation_norm_length_moe",
            "activation_batch",
            None,
            None,
            None,
        )
        input_axis = (
-            "activation_norm_length",
+            "activation_norm_length_moe",
            "activation_batch",
            None,
-            "activation_embed",
+            "activation_embed_moe",
        )
        dispatch_axis = (
            "activation_exp",
            "activation_batch_no_exp",
            None,
            None,
-            "activation_embed",
+            "activation_embed_moe",
        )
        mlp_axis = (
            "activation_exp",
@@ -1788,23 +1788,23 @@ def dense_matmul(
       else:
         mask_axes = (
             "activation_batch",
-            "activation_norm_length",
+            "activation_norm_length_moe",
            None,
            None,
            None,
        )
        input_axis = (
            "activation_batch",
-            "activation_norm_length",
+            "activation_norm_length_moe",
            None,
-            "activation_embed",
+            "activation_embed_moe",
        )
        dispatch_axis = (
            "activation_exp",
            "activation_batch_no_exp",
            None,
            None,
-            "activation_embed",
+            "activation_embed_moe",
        )
        mlp_axis = (
            "activation_exp",
@@ -1835,9 +1835,9 @@ def dense_matmul(
          (
              None,
              "activation_batch_no_exp",
-              "activation_norm_length",
+              "activation_norm_length_moe",
              None,
-              "activation_embed",
+              "activation_embed_moe",
          ),
      )
      dispatch = self._maybe_shard_with_logical(
@@ -1899,7 +1899,7 @@ def dense_matmul(
              "activation_exp",
              "activation_batch_no_exp",
              None,
-              "activation_embed",
+              "activation_embed_moe",
          ),
      )
      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
@@ -1922,7 +1922,9 @@ def dense_matmul(
      )
      return output, lb_loss, bias_updates
    else:
-      inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = self._maybe_shard_with_logical(
+          inputs, ("activation_batch", "activation_norm_length_moe", "activation_embed_moe")
+      )
      with jax.named_scope("wi_0"):
        layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
            "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
@@ -2082,7 +2084,7 @@ def __init__(
        num_experts_per_tok=self.config.num_experts_per_tok,
        mesh=self.mesh,
        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
-        kernel_axes=("embed", None),
+        kernel_axes=("embed_moe", None),
        intermediate_dim=self.config.moe_mlp_dim,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
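
The moe.py side of the change is a mechanical rename at every point where MoE activations and kernels are annotated with logical axis names. As a rough sketch of the pattern those call sites rely on, the snippet below defines a hypothetical shard_with_logical helper built from Flax's public API; it is not MaxText's _maybe_shard_with_logical, whose actual implementation (for example its behavior when no mesh is present) is not shown in this diff.

# Hypothetical helper, for illustration only: translate logical names through
# the rules table, then constrain the sharding on the current mesh.
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh

RULES = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_norm_length_moe", ("context",)),
    ("activation_embed_moe", ("tensor",)),
)

def shard_with_logical(x, logical_axes, rules=RULES):
  pspec = nn.logical_to_mesh_axes(logical_axes, rules)
  return jax.lax.with_sharding_constraint(x, pspec)

# Usage sketch on a trivial 1x1x1x1 mesh so it also runs on one device.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1, 1), ("data", "fsdp", "context", "tensor"))

@jax.jit
def constrain_moe_inputs(inputs):
  return shard_with_logical(inputs, ("activation_batch", "activation_norm_length_moe", "activation_embed_moe"))

with mesh:
  out = constrain_moe_inputs(jnp.ones((8, 16, 32)))

With the split, constraints that point at activation_norm_length_moe / activation_embed_moe instead of the shared names mean a future change to only the *_moe rules would alter how these MoE tensors are laid out, while attention and dense-MLP tensors keep the old rules.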
