From a80c31fa3add5334439586dd49463d4a21477a44 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 13 May 2026 17:22:06 +0000 Subject: [PATCH 1/2] add 2d fsdp custom mesh --- src/maxtext/common/common_types.py | 1 + .../configs/custom_mesh_and_rule/2d-fsdp.yml | 75 + src/maxtext/configs/types.py | 35 +- .../vllm/maxtext_vllm_adapter/adapter.py | 2 +- tests/utils/sharding_dump.py | 16 +- .../input_shardings.json | 178 ++ .../logical_shardings.json | 980 ++++++ .../named_shardings.json | 2696 +++++++++++++++++ 8 files changed, 3974 insertions(+), 9 deletions(-) create mode 100644 src/maxtext/configs/custom_mesh_and_rule/2d-fsdp.yml create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/input_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/logical_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/named_shardings.json diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 86811063a6..5d94758238 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -146,3 +146,4 @@ class CustomRule(enum.Enum): CP_AS_EP = "cp-as-ep" # Support CP and EP together EP_AS_CP = "ep-as-cp" # Support EP only PIPELINE_LARGE_MOE = "pipeline-large-moe" + FSDP_2D = "2d-fsdp" diff --git a/src/maxtext/configs/custom_mesh_and_rule/2d-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/2d-fsdp.yml new file mode 100644 index 0000000000..aaadd394f8 --- /dev/null +++ b/src/maxtext/configs/custom_mesh_and_rule/2d-fsdp.yml @@ -0,0 +1,75 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# When scaling to a large number of devices with limited model dimensions, +# introducing an additional FSDP axis avoids hitting sharding limits and improves +# GMM efficiency. This rule demonstrates using both `fsdp` and `fsdp_transpose` +# to enable efficient training across O(1000) chips. + +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert'] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']] +context_sharding: 'context' +logical_axis_rules: [ + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']], + # Vocab Weights + ['vocab', []], + ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_length_attn', ['context']], + ['activation_q_length', ['context']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + # Attention Weights + ['q_lora', ['fsdp']], + ["q_lora_up_proj", ['fsdp_transpose', 'expert']], + ['kv_lora', ['fsdp']], + ["kv_lora_up_proj", ['fsdp_transpose', 'expert']], + # ========================================== + # Mixture of Experts (MoE) + # 
========================================== + # MoE Activations + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_length_moe', ['context']], + ['activation_norm_length_moe', ['context']], + ['activation_mlp_moe', []], + ['activation_exp', ['expert']], + # MoE Weights + ['exp', 'expert'], + ['mlp_moe', ['fsdp_transpose']], + ['embed_moe', ['fsdp', 'context']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', []], + # Note activation batch and length also get used in vocab + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_length', ['context']], + ['activation_norm_length', ['context']], + ['activation_embed', []], + ['activation_stage', 'stage'], + # General Weights + ['mlp', ['fsdp_transpose']], + ['embed', ['fsdp', 'context', 'expert']], + ['norm', []], + ['layers', 'stage'], + ] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c364484c7f..0404953cf6 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2994,13 +2994,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "tensor": self.ici_tensor_parallelism, "tensor_transpose": self.ici_tensor_transpose_parallelism, "tensor_sequence": self.ici_tensor_sequence_parallelism, - "model": self.ici_tensor_parallelism, "expert": self.ici_expert_parallelism, "autoregressive": self.ici_autoregressive_parallelism, - "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads - "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP } - self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] dcn_map = { "diloco": self.dcn_diloco_parallelism, @@ -3014,12 +3010,37 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "tensor": 
self.dcn_tensor_parallelism, "tensor_transpose": self.dcn_tensor_transpose_parallelism, "tensor_sequence": self.dcn_tensor_sequence_parallelism, - "model": self.dcn_tensor_parallelism, "expert": self.dcn_expert_parallelism, "autoregressive": self.dcn_autoregressive_parallelism, - "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads - "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP } + + # Conditionally include vLLM RPA specific axes + if self.attention == "vllm_rpa": + ici_map.update( + { + "model": self.ici_tensor_parallelism, + "attn_dp": 1, + "attn_dp_expert": 1, + } + ) + dcn_map.update( + { + "model": self.dcn_tensor_parallelism, + "attn_dp": 1, + "attn_dp_expert": 1, + } + ) + + # Validate that any axis with configured parallelism > 1 is present in mesh_axes + for axis, ici_size in ici_map.items(): + if axis not in self.mesh_axes: + if ici_size > 1 or dcn_map[axis] > 1: + raise ValueError( + f"Mesh axis '{axis}' has configured parallelism > 1 " + f"(ici: {ici_size}, dcn: {dcn_map[axis]}) " + f"but is not included in self.mesh_axes: {self.mesh_axes}" + ) + self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] # Diloco params diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 9fa42d6abe..07231f965e 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None: model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) - self.model = nnx.data(model) \ No newline at end of file + self.model = nnx.data(model) diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index 
88958aa0ec..7c8007b63f 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -62,6 +62,13 @@ "cp-as-ep", ("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"), ), + ( + "deepseek2-16b", + "tpu7x-8", + 1, + "2d-fsdp", + ("ici_fsdp_parallelism=-1", "ici_fsdp_transpose_parallelism=2"), + ), ("qwen3-0.6b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")), @@ -168,7 +175,14 @@ def main(argv: Sequence[str]) -> None: validate_config(config) print(f"Sharding debug: {config.debug_sharding}") - rule_name = f"rule_{config.custom_mesh_and_rule}" if config.custom_mesh_and_rule else "rule_default" + # Extract custom_mesh_and_rule directly from argv test case string + custom_mesh_and_rule = "" + for arg in argv: + if arg.startswith("custom_mesh_and_rule="): + custom_mesh_and_rule = arg.split("=", 1)[1] + break + + rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default" # Find overrides from argv to append to rule_name overrides = [] for arg in argv: diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/input_shardings.json new file mode 100644 index 0000000000..df2ae99a0c --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/input_shardings.json @@ -0,0 +1,178 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[96,2048,2048]": { + 
"logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 
'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_op/arr: int8[1,4,4]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None, None)" + } + }, + { + "attention_op/arr: int32[2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None,)" + } + }, + { + "attention_op/query: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "deepseek/mlp: bfloat16[96,2048,2048]": { + "logic_axes": 
"('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "deepseek/x: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "moe/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[96,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "moe/w0_kernel: bfloat16[64,2048,1408]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None, None, None)" + } + }, + { + "moe/w1_kernel: bfloat16[64,2048,1408]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None, None, None)" + } + }, + { + "moe/wo_kernel: bfloat16[64,1408,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None, None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'fsdp_transpose'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/logical_shardings.json new file mode 100644 index 0000000000..8d30b919f8 --- /dev/null +++ 
b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/logical_shardings.json @@ -0,0 +1,980 @@ +{ + ".step": { + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + 
".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + 
"kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + 
"partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/named_shardings.json new file mode 100644 index 0000000000..d982ba6334 --- /dev/null +++ 
b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_2d-fsdp_ici_fsdp_parallelism=-1_ici_fsdp_transpose_parallelism=2/named_shardings.json @@ -0,0 +1,2696 @@ +{ + ".step": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + 
} + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + 
"expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context" + ], + null, + null + ], + "shape": [ + 2048, + 
26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp_transpose", + [ + "fsdp", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, 
+ "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + 
"mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + 
"expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp_transpose", + [ + "fsdp", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { 
+ "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + 
null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { 
+ "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", 
+ "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "context" + ], + "fsdp_transpose" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 
1, + "expert": 1 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp_transpose", + [ + "fsdp", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + "fsdp_transpose" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp_transpose", + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ 
+ null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + 
"context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp_transpose", + "expert" + ] + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + "fsdp", + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "fsdp_transpose", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 4, + "fsdp_transpose": 2, + "context": 1, + "expert": 1 + } + }, + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file From f28f38c40075d45c22238d0cffae03570e1a237e Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 13 May 2026 18:13:28 +0000 Subject: [PATCH 2/2] deprecate old 2dfsdp functions --- src/maxtext/configs/base.yml | 2 - .../configs/models/deepseek3-671b-2dfsdp.yml | 86 ------------------- src/maxtext/configs/types.py | 5 -- src/maxtext/layers/moe.py | 7 -- tests/unit/configs_test.py | 1 - 5 files changed, 101 deletions(-) delete mode 100644 src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml diff --git a/src/maxtext/configs/base.yml 
b/src/maxtext/configs/base.yml index cc2f674fd4..9594c04f38 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -250,8 +250,6 @@ moe_fsdp_use_two_stage_all_gather: false # Shard the expert dimension of the MLP weights on the FSDP axis. # This configuration is recommended only when num_experts is a multiple of fsdp_parallelism shard_exp_on_fsdp: False -# use fsdp and fsdp_transpose axes for sharding the moe weights -use_2d_fsdp_sharding: False # deepseek moe first_num_dense_layers: 0 # number of initial dense layers in the model diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml deleted file mode 100644 index be9422895b..0000000000 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes - -# For DeepSeek default device-limited routing, -# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments. 
- -base_emb_dim: 7168 -base_num_query_heads: 128 -base_num_kv_heads: 128 -base_mlp_dim: 18432 -base_moe_mlp_dim: 2048 -base_num_decoder_layers: 61 -first_num_dense_layers: 3 -mlp_activations: ["silu","linear"] -vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False -normalization_layer_epsilon: 1.0e-6 -num_experts: 256 -num_experts_per_tok: 8 -shared_experts: 1 -routed_scaling_factor: 2.5 -routed_score_func: "sigmoid" -routed_bias: True -decoder_block: "deepseek" -# MLA -attention_type: "mla" -q_lora_rank: 1536 -kv_lora_rank: 512 -qk_nope_head_dim: 128 -qk_rope_head_dim: 64 -v_head_dim: 128 -mscale: 1.0 -# RoPE -rope_type: "yarn" -rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000 -max_position_embeddings: 163840 -original_max_position_embeddings: 4096 -rope_factor: 40 -beta_fast: 32 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False - -override_logical_axis_rules: True -mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_norm_length', ['context']], - ['activation_norm_length_moe', ['context']], - ['activation_heads', []], - ['activation_stage', 'stage'], - ['embed', ['fsdp']], - ['embed_moe', ['fsdp']], - ['embed_vocab', ['fsdp', 'fsdp_transpose']], - ['q_lora', ['fsdp']], - ['kv_lora', ['fsdp']], - ['layers', 'stage'], - ['q_lora_up_proj', 
['fsdp_transpose', 'expert']], - ['kv_lora_up_proj', ['fsdp_transpose', 'expert']], - ['q_heads', ['fsdp_transpose', 'expert']], - ['kv_heads', ['fsdp_transpose', 'expert']], - ['heads', ['fsdp_transpose', 'expert']], - ['mlp', ['fsdp_transpose', 'expert']], - ['mlp_moe', ['fsdp_transpose', 'expert']], - ['diloco', 'diloco'], -] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0404953cf6..c63270ffe3 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -221,7 +221,6 @@ class ProfilerType(str, Enum): "deepseek2-16b", "deepseek2-236b", "deepseek3-671b", - "deepseek3-671b-2dfsdp", "deepseek3-671b-batchsplit", "deepseek3-test", "deepseek3-tiny", @@ -705,10 +704,6 @@ class MoEGeneral(BaseModel): description="Shard the expert dimension of the MLP weights on the FSDP axis, " "and recommended only when num_experts is a multiple of fsdp_parallelism", ) - use_2d_fsdp_sharding: bool = Field( - False, - description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.", - ) norm_topk_prob: bool = Field( False, description="Enable top-k probability normalization for router weights (Qwen3-specific).", diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..c0be11cf0f 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -358,9 +358,6 @@ def __init__( # special sharding for dsv3 self.wi_kernel_axes = ("embed_moe", None, "mlp_moe") self.wo_kernel_axes = ("embed_moe", "mlp_moe", None) - elif self.config.use_2d_fsdp_sharding: - self.wi_kernel_axes = ("embed_moe", "mlp_moe", None) - self.wo_kernel_axes = ("embed_moe", "mlp_moe", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: @@ -1217,10 +1214,6 @@ def get_routed_moe_shardings(is_batch_sharded_by_expert): w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp")) w1_pspec = 
self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp")) wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None)) - elif self.config.use_2d_fsdp_sharding: - w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None)) - w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None)) - wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None)) else: # These are the main shardings used by default - they use funky rules to AG over FSDP. w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp")) diff --git a/tests/unit/configs_test.py b/tests/unit/configs_test.py index 2a7185faaa..8eae0abf2e 100644 --- a/tests/unit/configs_test.py +++ b/tests/unit/configs_test.py @@ -200,7 +200,6 @@ def test_gpt_configs(config_file): os.path.join(CONFIGS_DIR, "models", "deepseek2-236b.yml"), os.path.join(CONFIGS_DIR, "models", "deepseek3-test.yml"), os.path.join(CONFIGS_DIR, "models", "deepseek3-671b.yml"), - os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-2dfsdp.yml"), os.path.join(CONFIGS_DIR, "models", "deepseek3-671b-batchsplit.yml"), ]