|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys.""" |
| 16 | + |
| 17 | +from dataclasses import dataclass |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class DEEPSEEK_VLLM_MAPPING: |
| 22 | + """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys.""" |
| 23 | + |
| 24 | + @staticmethod |
| 25 | + def to_hf_hook_fns(): |
| 26 | + def flatten_3d_to_2d(val): |
| 27 | + # Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim) |
| 28 | + if val.ndim == 3: |
| 29 | + return val.reshape(val.shape[0], -1) |
| 30 | + return val |
| 31 | + |
| 32 | + return { |
| 33 | + # MaxText MLA weights are 3D (Rank, Heads, HeadDim). |
| 34 | + # tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them. |
| 35 | + "base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d, |
| 36 | + "base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d, |
| 37 | + "base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d, |
| 38 | + } |
| 39 | + |
| 40 | + @staticmethod |
| 41 | + def to_hf_transpose_keys(): |
| 42 | + """Returns a list of keys for weights that need to be transposed. |
| 43 | +
|
| 44 | + Returns: |
| 45 | + An empty dictionary, as no keys require transposition for this mapping. |
| 46 | + """ |
| 47 | + return {} |
| 48 | + |
| 49 | + @staticmethod |
| 50 | + def lora_to_hf_mappings(): |
| 51 | + """Provides the mapping for LoRA (Low-Rank Adaptation) weights. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + None, as LoRA mappings are not defined for this model. |
| 55 | + """ |
| 56 | + return None |
| 57 | + |
| 58 | + @staticmethod |
| 59 | + def to_hf_mapping(): |
| 60 | + mapping = { |
| 61 | + # --- Base Model Params --- |
| 62 | + # Map to HF names to be safe with loader regexes |
| 63 | + "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)), |
| 64 | + "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)), |
| 65 | + "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")), |
| 66 | + # MLA LAYERS (Map to HF Keys to trigger loader splitting logic) |
| 67 | + # Norms |
| 68 | + "base.decoder.layers.pre_self_attention_layer_norm.scale": ( |
| 69 | + "model.layers.*.input_layernorm.weight", |
| 70 | + (None, "layer"), |
| 71 | + ), |
| 72 | + "base.decoder.layers.post_self_attention_layer_norm.scale": ( |
| 73 | + "model.layers.*.post_attention_layernorm.weight", |
| 74 | + (None, "layer"), |
| 75 | + ), |
| 76 | + # MLA Norms |
| 77 | + "base.decoder.layers.self_attention.kv_norm.scale": ( |
| 78 | + "model.layers.*.self_attn.kv_a_layernorm.weight", |
| 79 | + (None, "layer"), |
| 80 | + ), |
| 81 | + "base.decoder.layers.self_attention.q_norm.scale": ( |
| 82 | + "model.layers.*.self_attn.q_a_layernorm.weight", |
| 83 | + (None, "layer"), |
| 84 | + ), |
| 85 | + # MLA Projections |
| 86 | + # We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj" |
| 87 | + # and performs the necessary split into k_b and v_b for the MLA kernel. |
| 88 | + "base.decoder.layers.self_attention.wq_a.kernel": ( |
| 89 | + "model.layers.*.self_attn.q_a_proj.weight", |
| 90 | + (None, "layer", "model", None), |
| 91 | + ), |
| 92 | + "base.decoder.layers.self_attention.wq_b.kernel": ( |
| 93 | + "model.layers.*.self_attn.q_b_proj.weight", |
| 94 | + (None, "layer", "model", None), |
| 95 | + ), |
| 96 | + "base.decoder.layers.self_attention.wkv_a.kernel": ( |
| 97 | + "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", |
| 98 | + (None, "layer", "model", None), |
| 99 | + ), |
| 100 | + "base.decoder.layers.self_attention.wkv_b.kernel": ( |
| 101 | + "model.layers.*.self_attn.kv_b_proj.weight", |
| 102 | + (None, "layer", "model", None), |
| 103 | + ), |
| 104 | + "base.decoder.layers.self_attention.out.kernel": ( |
| 105 | + "model.layers.*.self_attn.o_proj.weight", |
| 106 | + ("model", "layer", None, None), |
| 107 | + ), |
| 108 | + # DENSE MLP LAYERS (Map to vllm keys for safety/consistency) |
| 109 | + "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")), |
| 110 | + "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")), |
| 111 | + "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)), |
| 112 | + # MOE LAYERS (Map to INTERNAL keys to bypass loader stacking) |
| 113 | + # Since MaxText experts are already fused/stacked, we map directly to the |
| 114 | + # internal `tpu-inference` param names. The loader will fail to find |
| 115 | + # "experts.{i}" in the name and fall back to loading these directly, |
| 116 | + # which is exactly what we want for performance. |
| 117 | + # Shared Experts |
| 118 | + "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": ( |
| 119 | + "layers.*.shared_experts.kernel_gating_DF", |
| 120 | + (None, "layer", "model"), |
| 121 | + ), |
| 122 | + "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": ( |
| 123 | + "layers.*.shared_experts.kernel_up_proj_DF", |
| 124 | + (None, "layer", "model"), |
| 125 | + ), |
| 126 | + "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": ( |
| 127 | + "layers.*.shared_experts.kernel_down_proj_FD", |
| 128 | + ("model", "layer", None), |
| 129 | + ), |
| 130 | + # Router |
| 131 | + "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": ( |
| 132 | + "layers.*.custom_module.router.kernel_DE", |
| 133 | + (None, "layer", "model"), |
| 134 | + ), |
| 135 | + "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": ( |
| 136 | + "layers.*.custom_module.router.bias_E", |
| 137 | + (None, "layer", "model"), |
| 138 | + ), |
| 139 | + # Routed Experts (Fused) |
| 140 | + "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": ( |
| 141 | + "layers.*.custom_module.kernel_gating_EDF", |
| 142 | + ("expert", "layer", None, "model"), |
| 143 | + ), |
| 144 | + "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": ( |
| 145 | + "layers.*.custom_module.kernel_up_proj_EDF", |
| 146 | + ("expert", "layer", None, "model"), |
| 147 | + ), |
| 148 | + "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": ( |
| 149 | + "layers.*.custom_module.kernel_down_proj_EFD", |
| 150 | + ("expert", "layer", "model", None), |
| 151 | + ), |
| 152 | + # MTP BLOCK (Included for completeness, but typically skipped by current loader) |
| 153 | + "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)), |
| 154 | + "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)), |
| 155 | + "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")), |
| 156 | + } |
| 157 | + return mapping |
0 commit comments