1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
@@ -146,3 +146,4 @@ class CustomRule(enum.Enum):
CP_AS_EP = "cp-as-ep" # Support CP and EP together
EP_AS_CP = "ep-as-cp" # Support EP only
PIPELINE_LARGE_MOE = "pipeline-large-moe"
FSDP_2D = "2d-fsdp"
2 changes: 0 additions & 2 deletions src/maxtext/configs/base.yml
@@ -250,8 +250,6 @@ moe_fsdp_use_two_stage_all_gather: false
# Shard the expert dimension of the MLP weights on the FSDP axis.
# This configuration is recommended only when num_experts is a multiple of fsdp_parallelism
shard_exp_on_fsdp: False
# use fsdp and fsdp_transpose axes for sharding the moe weights
use_2d_fsdp_sharding: False

# deepseek moe
first_num_dense_layers: 0 # number of initial dense layers in the model
75 changes: 75 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/2d-fsdp.yml
@@ -0,0 +1,75 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# When scaling to a large number of devices with limited model dimensions,
# introducing a second FSDP axis avoids exceeding the sharding limit of any
# single weight dimension and improves GMM (grouped matmul) efficiency. This
# rule demonstrates using both `fsdp` and `fsdp_transpose` to enable efficient
# training across O(1000) chips.

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']]
context_sharding: 'context'
logical_axis_rules: [
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
# Vocab Weights
['vocab', []],
['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_length_attn', ['context']],
['activation_q_length', ['context']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
# Attention Weights
['q_lora', ['fsdp']],
["q_lora_up_proj", ['fsdp_transpose', 'expert']],
['kv_lora', ['fsdp']],
["kv_lora_up_proj", ['fsdp_transpose', 'expert']],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_length_moe', ['context']],
['activation_norm_length_moe', ['context']],
['activation_mlp_moe', []],
['activation_exp', ['expert']],
# MoE Weights
['exp', 'expert'],
['mlp_moe', ['fsdp_transpose']],
['embed_moe', ['fsdp', 'context']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', []],
# Note: activation batch and length are also used in vocab
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_length', ['context']],
['activation_norm_length', ['context']],
['activation_embed', []],
['activation_stage', 'stage'],
# General Weights
['mlp', ['fsdp_transpose']],
['embed', ['fsdp', 'context', 'expert']],
['norm', []],
['layers', 'stage'],
]
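As context for the rules above, here is a minimal standalone sketch — plain JAX, not MaxText code, and it assumes at least 4 devices — of what mapping a weight's logical axes to both `fsdp` and `fsdp_transpose` achieves: the weight is tiled across both mesh axes at once rather than sharded along a single oversized axis.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)  # 2 x 2 device grid
mesh = Mesh(devices, axis_names=("fsdp", "fsdp_transpose"))

w = jnp.zeros((1024, 4096))  # a (embed, mlp) weight; sizes hypothetical
# Simplified from the rules above: 'embed' -> 'fsdp', 'mlp' -> 'fsdp_transpose'
w = jax.device_put(w, NamedSharding(mesh, P("fsdp", "fsdp_transpose")))
# Each device now holds a (512, 2048) tile instead of the (256, 4096) strip
# that a single 4-way 'fsdp' axis on the first dimension would give.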
86 changes: 0 additions & 86 deletions src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

This file was deleted.

40 changes: 28 additions & 12 deletions src/maxtext/configs/types.py
@@ -221,7 +221,6 @@ class ProfilerType(str, Enum):
"deepseek2-16b",
"deepseek2-236b",
"deepseek3-671b",
"deepseek3-671b-2dfsdp",
"deepseek3-671b-batchsplit",
"deepseek3-test",
"deepseek3-tiny",
@@ -705,10 +704,6 @@ class MoEGeneral(BaseModel):
description="Shard the expert dimension of the MLP weights on the FSDP axis, "
"and recommended only when num_experts is a multiple of fsdp_parallelism",
)
use_2d_fsdp_sharding: bool = Field(
False,
description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.",
)
norm_topk_prob: bool = Field(
False,
description="Enable top-k probability normalization for router weights (Qwen3-specific).",
@@ -2994,13 +2989,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"tensor": self.ici_tensor_parallelism,
"tensor_transpose": self.ici_tensor_transpose_parallelism,
"tensor_sequence": self.ici_tensor_sequence_parallelism,
"model": self.ici_tensor_parallelism,
"expert": self.ici_expert_parallelism,
"autoregressive": self.ici_autoregressive_parallelism,
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
}
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

dcn_map = {
"diloco": self.dcn_diloco_parallelism,
@@ -3014,12 +3005,37 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"tensor": self.dcn_tensor_parallelism,
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
"model": self.dcn_tensor_parallelism,
"expert": self.dcn_expert_parallelism,
"autoregressive": self.dcn_autoregressive_parallelism,
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
}

# Conditionally include vLLM RPA specific axes
if self.attention == "vllm_rpa":
ici_map.update(
{
"model": self.ici_tensor_parallelism,
"attn_dp": 1,
"attn_dp_expert": 1,
}
)
dcn_map.update(
{
"model": self.dcn_tensor_parallelism,
"attn_dp": 1,
"attn_dp_expert": 1,
}
)

# Validate that any axis with configured parallelism > 1 is present in mesh_axes
for axis, ici_size in ici_map.items():
if axis not in self.mesh_axes:
if ici_size > 1 or dcn_map[axis] > 1:
raise ValueError(
f"Mesh axis '{axis}' has configured parallelism > 1 "
f"(ici: {ici_size}, dcn: {dcn_map[axis]}) "
f"but is not included in self.mesh_axes: {self.mesh_axes}"
)
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

# Diloco params
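The new loop above fails fast when a parallelism request targets an axis missing from `mesh_axes`. A standalone sketch of the same check, with hypothetical axis sizes:

mesh_axes = ["data", "fsdp", "context"]
ici_map = {"data": 1, "fsdp": -1, "context": 1, "tensor": 4}
dcn_map = {"data": 2, "fsdp": 1, "context": 1, "tensor": 1}

for axis, ici_size in ici_map.items():
    if axis not in mesh_axes and (ici_size > 1 or dcn_map[axis] > 1):
        raise ValueError(
            f"Mesh axis '{axis}' has configured parallelism > 1 "
            f"(ici: {ici_size}, dcn: {dcn_map[axis]}) "
            f"but is not included in mesh_axes: {mesh_axes}"
        )
# Raises for 'tensor': 4-way ICI parallelism is requested, but the axis is
# absent from mesh_axes, so the request would otherwise be silently dropped.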
@@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
model = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
self.model = nnx.data(model)
7 changes: 0 additions & 7 deletions src/maxtext/layers/moe.py
@@ -358,9 +358,6 @@ def __init__(
# special sharding for dsv3
self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
elif self.config.use_2d_fsdp_sharding:
self.wi_kernel_axes = ("embed_moe", "mlp_moe", None)
self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
elif self.config.use_batch_split_schedule:
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
else:
@@ -1217,10 +1214,6 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert):
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
elif self.config.use_2d_fsdp_sharding:
w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
else:
# These are the main shardings used by default - they use funky rules to AG over FSDP.
w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
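With these branches deleted, the MoE kernels keep their logical axis names and the 2d-fsdp rule file decides the mesh mapping. A minimal sketch of that resolution using flax's logical rules (rule subset copied from 2d-fsdp.yml above; assumes the current flax `logical_to_mesh_axes` API):

import flax.linen as nn

rules = (
    ("embed_moe", ("fsdp", "context")),
    ("mlp_moe", ("fsdp_transpose",)),
)
# Resolve logical kernel axes to a mesh PartitionSpec under the 2d-fsdp rules;
# unmatched or None axes stay unsharded.
spec = nn.logical_to_mesh_axes(("embed_moe", "mlp_moe", None), rules)
print(spec)  # PartitionSpec(('fsdp', 'context'), ('fsdp_transpose',), None)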
1 change: 0 additions & 1 deletion tests/unit/configs_test.py
@@ -200,7 +200,6 @@ def test_gpt_configs(config_file):
os.path.join(CONFIGS_DIR, "models", "deepseek2-236b.yml"),
os.path.join(CONFIGS_DIR, "models", "deepseek3-test.yml"),
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b.yml"),
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-2dfsdp.yml"),
os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-batchsplit.yml"),
]

16 changes: 15 additions & 1 deletion tests/utils/sharding_dump.py
@@ -62,6 +62,13 @@
"cp-as-ep",
("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"),
),
(
"deepseek2-16b",
"tpu7x-8",
1,
"2d-fsdp",
("ici_fsdp_parallelism=-1", "ici_fsdp_transpose_parallelism=2"),
),
("qwen3-0.6b", "tpu7x-16", 1, "", ()),
("gpt-oss-20b", "tpu7x-16", 1, "", ()),
("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")),
@@ -168,7 +175,14 @@ def main(argv: Sequence[str]) -> None:
validate_config(config)
print(f"Sharding debug: {config.debug_sharding}")

rule_name = f"rule_{config.custom_mesh_and_rule}" if config.custom_mesh_and_rule else "rule_default"
# Extract custom_mesh_and_rule directly from the argv override strings
custom_mesh_and_rule = ""
for arg in argv:
if arg.startswith("custom_mesh_and_rule="):
custom_mesh_and_rule = arg.split("=", 1)[1]
break

rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
# Find overrides from argv to append to rule_name
overrides = []
for arg in argv:
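A quick standalone check of the argv parsing added above (the argv contents here are hypothetical):

argv = [
    "sharding_dump.py",  # hypothetical invocation
    "model_name=deepseek2-16b",
    "custom_mesh_and_rule=2d-fsdp",
    "ici_fsdp_parallelism=-1",
]

custom_mesh_and_rule = ""
for arg in argv:
    if arg.startswith("custom_mesh_and_rule="):
        custom_mesh_and_rule = arg.split("=", 1)[1]
        break

rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
print(rule_name)  # -> rule_2d-fsdp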