Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,9 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
Expand All @@ -448,6 +450,8 @@ logical_axis_rules: [
['activation_attn_length_no_exp', ['context']],
['activation_length_no_exp', ['sequence', 'context']],
['activation_length_no_exp', ['context']],
['activation_length_no_exp_moe', ['sequence', 'context']],
['activation_length_no_exp_moe', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_q_length', ['context', 'expert']],
['activation_q_length_no_exp', ['context']],
Expand All @@ -456,6 +460,7 @@ logical_axis_rules: [
['activation_kv_length', []],
['activation_attn_embed', ['tensor', 'tensor_transpose']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
Expand Down Expand Up @@ -484,6 +489,10 @@ logical_axis_rules: [
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_moe', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ logical_axis_rules: [
['activation_q_length', ['expert']],
['activation_attn_embed', ['tensor']],
['activation_embed', ['tensor']],
['activation_embed_moe', ['tensor']],
['activation_mlp', ['tensor']],
['activation_kv', ['tensor']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
Expand All @@ -56,6 +57,7 @@ logical_axis_rules: [
['kv_heads', ['tensor']],
['embed', ['fsdp', 'expert']],
['embed_no_exp', ['fsdp']],
['embed_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['norm', ['tensor']],
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ logical_axis_rules: [
['decode_batch', ['fsdp']],
['embed', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['exp_with_fsdp', 'fsdp'],
Expand Down
120 changes: 62 additions & 58 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.

contract_ind = tuple(range(0, len(norm_axis)))
output_sharding = (
create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
if self.shard_mode == ShardMode.EXPLICIT
else None
)
Expand Down Expand Up @@ -351,16 +351,16 @@ def __init__(

if self.config.shard_exp_on_fsdp:
# special sharding for dsv3
self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
self.wi_kernel_axes = ("embed_moe", None, "mlp")
self.wo_kernel_axes = ("embed_moe", "mlp", None)
elif self.config.use_2d_fsdp_sharding:
self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
self.wi_kernel_axes = ("embed_moe", "mlp", None)
self.wo_kernel_axes = ("embed_moe", "mlp", None)
elif self.config.use_batch_split_schedule:
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
else:
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
self.wi_kernel_axes = ("exp", "embed_moe", "mlp")
self.wo_kernel_axes = ("exp", "mlp", "embed_moe")

if self.config.attention == "vllm_rpa":
# vLLM uses 'model' as the tensor parallelism axis name
Expand Down Expand Up @@ -437,7 +437,7 @@ def __init__(

if self.config.mlp_bias:
wi_bias_axes = ("exp", "activation_mlp")
wo_bias_axes = ("exp", "activation_embed")
wo_bias_axes = ("exp", "activation_embed_moe")
wi_bias_shape = (self.num_experts, self.intermediate_dim)
wo_bias_shape = (self.num_experts, self.config.emb_dim)
self.wi_0_bias = nnx.Param(
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def gmm(
self._expert_parallelism_name
in tuple(
filter(
lambda tup: tup[0] == "activation_batch",
lambda tup: tup[0] == "activation_batch_moe",
self.config.logical_axis_rules,
)
)[
Expand All @@ -1028,26 +1028,26 @@ def gmm(
except: # pylint: disable=bare-except
is_batch_sharded_by_expert = False
if is_batch_sharded_by_expert and inputs.shape[0] > 1:
batch_logical_axis = "activation_batch"
batch_logical_axis = "activation_batch_moe"
else:
batch_logical_axis = "activation_batch_no_exp"
batch_logical_axis = "activation_batch_no_exp_moe"

if self.get_tensor_transpose_parallelism_size() > 1:
input_partition_pspec = self._logical_to_mesh_axes(
(batch_logical_axis, "activation_norm_length", "activation_embed")
(batch_logical_axis, "activation_length_no_exp_moe", "activation_embed_moe")
)
w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
else:
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp_moe", None))
w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))

gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp_moe", None))
if self.config.model_name.startswith("deepseek3"):
pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp_moe", None))
else:
# pre_bias_logits is None for non-DeepSeek v3 models
pre_bias_logits_pspec = None
Expand Down Expand Up @@ -1099,7 +1099,7 @@ def gmm(
P(), # Replicate the input key
),
out_specs=(
self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
self._logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp_moe", "activation_embed_moe")),
P(), # Handle None or replicate the output
P(), # Handle None or replicate the output
),
Expand Down Expand Up @@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))

if self.get_tensor_transpose_parallelism_size() > 1:
input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
input_axes = (batch_logical_axis, "activation_length_no_exp_moe", "activation_embed_moe")
else:
input_axes = (batch_logical_axis, "activation_norm_length", None)
input_axes = (batch_logical_axis, "activation_length_no_exp_moe", None)

gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
gate_logits_axes = (batch_logical_axis, "activation_length_no_exp_moe", None)
if self.config.model_name.startswith("deepseek3"):
pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
pre_bias_logits_axes = (batch_logical_axis, "activation_length_no_exp_moe", None)
else:
pre_bias_logits_axes = None

Expand All @@ -1436,13 +1436,13 @@ def reshape_and_update_weights(self, weights, indices):
update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
index_update = (
self._maybe_shard_with_logical(
jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)
jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
),
self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)),
self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
indices,
)
weight_sharding = (
create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
if self.config.shard_mode == ShardMode.EXPLICIT
else None
)
Expand Down Expand Up @@ -1497,15 +1497,15 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
expert_mask,
(batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts),
)
expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None))
expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None))
expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2)
expert_token_count = jnp.reshape(
expert_token_count_fused,
((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)),
)
expert_token_count = self._maybe_shard_with_logical(
expert_token_count,
("activation_batch", "activation_norm_length", None, None, None),
("activation_batch_moe", "activation_length_no_exp_moe", None, None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
Expand Down Expand Up @@ -1585,15 +1585,15 @@ def generate_masks(self, top_k_indices, softmax_probs):
expert_mask,
(batch_size, seq_len * self.num_experts_per_tok, self.num_experts),
)
expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None))
expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None))
expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1)
expert_token_count = jnp.reshape(
expert_token_count_fused,
((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)),
)
expert_token_count = self._maybe_shard_with_logical(
expert_token_count,
("activation_batch", "activation_norm_length", None, None),
("activation_batch_moe", "activation_length_no_exp_moe", None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
Expand Down Expand Up @@ -1691,11 +1691,13 @@ def dense_matmul(
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
"""Dense matrix multiplication."""
# gate_logits: batch, length, expert
gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
gate_logits = self._maybe_shard_with_logical(
gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
)
if self.config.model_name.startswith("deepseek3"):
# pre_bias_logits is None for non-DeepSeek v3 models
pre_bias_logits = self._maybe_shard_with_logical(
pre_bias_logits, ("activation_batch", "activation_norm_length", None)
pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
)
top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
Expand Down Expand Up @@ -1735,16 +1737,16 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks(
top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment
)
mask_axes = ("activation_batch", "activation_norm_length", None, None)
mask_axes = ("activation_batch_moe", "activation_length_no_exp_moe", None, None)
dispatch_axis = (
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
"activation_embed",
"activation_embed_moe",
)
mlp_axis = (
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
"activation_mlp",
)
Expand All @@ -1759,56 +1761,56 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
mask_axes = (
"activation_norm_length",
"activation_batch",
"activation_length_no_exp_moe",
"activation_batch_moe",
None,
None,
None,
)
input_axis = (
"activation_norm_length",
"activation_batch",
"activation_length_no_exp_moe",
"activation_batch_moe",
None,
"activation_embed",
"activation_embed_moe",
)
dispatch_axis = (
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
None,
"activation_embed",
"activation_embed_moe",
)
mlp_axis = (
"activation_exp",
"activation_batch_no_exp",
"activaton_batch_no_exp_moe",
None,
None,
"activation_mlp",
)
else:
mask_axes = (
"activation_batch",
"activation_norm_length",
"activation_batch_moe",
"activation_length_no_exp_moe",
None,
None,
None,
)
input_axis = (
"activation_batch",
"activation_norm_length",
"activation_batch_moe",
"activation_length_no_exp_moe",
None,
"activation_embed",
"activation_embed_moe",
)
dispatch_axis = (
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
None,
"activation_embed",
"activation_embed_moe",
)
mlp_axis = (
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
None,
"activation_mlp",
Expand All @@ -1834,10 +1836,10 @@ def dense_matmul(
dispatch,
(
None,
"activation_batch_no_exp",
"activation_norm_length",
"activation_batch_no_exp_moe",
"activation_length_no_exp_moe",
None,
"activation_embed",
"activation_embed_moe",
),
)
dispatch = self._maybe_shard_with_logical(
Expand Down Expand Up @@ -1897,9 +1899,9 @@ def dense_matmul(
intermediate_layer,
(
"activation_exp",
"activation_batch_no_exp",
"activation_batch_no_exp_moe",
None,
"activation_embed",
"activation_embed_moe",
),
)
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
Expand All @@ -1922,7 +1924,9 @@ def dense_matmul(
)
return output, lb_loss, bias_updates
else:
inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = self._maybe_shard_with_logical(
inputs, ("activation_batch_moe", "activation_length_no_exp_moe", "activation_embed_moe")
)
with jax.named_scope("wi_0"):
layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
"BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
Expand Down Expand Up @@ -2082,7 +2086,7 @@ def __init__(
num_experts_per_tok=self.config.num_experts_per_tok,
mesh=self.mesh,
kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
kernel_axes=("embed_moe", None),
intermediate_dim=self.config.moe_mlp_dim,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
Expand Down
Loading