
Commit 3bceb75

Add Deepseek mapping for RL.
1 parent 40b8f0c commit 3bceb75

3 files changed

Lines changed: 320 additions & 1 deletion


src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -18,7 +18,8 @@
 dispatcher to retrieve the correct weight mapping configuration for a given
 model name. This allows for easy extension to support new models.
 """
-
+from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING

@@ -31,6 +32,10 @@ def __getattr__(self, name):
     return LLAMA3_VLLM_MAPPING
   elif name.startswith("qwen3"):
     return QWEN3_VLLM_MAPPING
+  elif name.startswith("deepseek3"):
+    return DEEPSEEK_VLLM_MAPPING
+  elif name.startswith("gpt-oss"):
+    return GPT_OSS_VLLM_MAPPING
   else:
     raise ValueError(f"{name} vLLM weight mapping not found.")

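As a rough, standalone sketch of the prefix dispatch these additions enable (the resolve() helper and the stand-in mapping objects below are hypothetical; the real lookup goes through the dispatcher's __getattr__ shown above):

# Minimal sketch only: mirrors the startswith() dispatch added in __init__.py.
DEEPSEEK_VLLM_MAPPING = "deepseek3 mapping"  # stand-in for the real class
GPT_OSS_VLLM_MAPPING = "gpt-oss mapping"     # stand-in for the real class

def resolve(name: str):
  if name.startswith("deepseek3"):
    return DEEPSEEK_VLLM_MAPPING
  elif name.startswith("gpt-oss"):
    return GPT_OSS_VLLM_MAPPING
  raise ValueError(f"{name} vLLM weight mapping not found.")

print(resolve("deepseek3-671b"))  # resolves to the DeepSeek mapping
print(resolve("gpt-oss-20b"))     # resolves to the GPT-OSS mapping

Any model name that does not match a known prefix falls through to the ValueError, matching the behavior of the dispatcher above.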
src/MaxText/integration/tunix/weight_mapping/deepseek3.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass


@dataclass
class DEEPSEEK_VLLM_MAPPING:
  """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys."""

  @staticmethod
  def to_hf_hook_fns():
    def flatten_3d_to_2d(val):
      # Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim)
      if val.ndim == 3:
        return val.reshape(val.shape[0], -1)
      return val

    return {
        # MaxText MLA weights are 3D (Rank, Heads, HeadDim).
        # tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them.
        "base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d,
        "base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d,
        "base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d,
    }

  @staticmethod
  def to_hf_transpose_keys():
    """Returns keys for weights that need to be transposed.

    Returns:
      An empty dictionary, as no keys require transposition for this mapping.
    """
    return {}

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_mapping():
    mapping = {
        # --- Base Model Params ---
        # Map to HF names to be safe with loader regexes
        "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)),
        "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)),
        "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")),
        # MLA LAYERS (Map to HF Keys to trigger loader splitting logic)
        # Norms
        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
            "model.layers.*.input_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.post_self_attention_layer_norm.scale": (
            "model.layers.*.post_attention_layernorm.weight",
            (None, "layer"),
        ),
        # MLA Norms
        "base.decoder.layers.self_attention.kv_norm.scale": (
            "model.layers.*.self_attn.kv_a_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.self_attention.q_norm.scale": (
            "model.layers.*.self_attn.q_a_layernorm.weight",
            (None, "layer"),
        ),
        # MLA Projections
        # We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj"
        # and performs the necessary split into k_b and v_b for the MLA kernel.
        "base.decoder.layers.self_attention.wq_a.kernel": (
            "model.layers.*.self_attn.q_a_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wq_b.kernel": (
            "model.layers.*.self_attn.q_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_a.kernel": (
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_b.kernel": (
            "model.layers.*.self_attn.kv_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.out.kernel": (
            "model.layers.*.self_attn.o_proj.weight",
            ("model", "layer", None, None),
        ),
        # DENSE MLP LAYERS (Map to vllm keys for safety/consistency)
        "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)),
        # MOE LAYERS (Map to INTERNAL keys to bypass loader stacking)
        # Since MaxText experts are already fused/stacked, we map directly to the
        # internal `tpu-inference` param names. The loader will fail to find
        # "experts.{i}" in the name and fall back to loading these directly,
        # which is exactly what we want for performance.
        # Shared Experts
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (
            "layers.*.shared_experts.kernel_gating_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (
            "layers.*.shared_experts.kernel_up_proj_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": (
            "layers.*.shared_experts.kernel_down_proj_FD",
            ("model", "layer", None),
        ),
        # Router
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (
            "layers.*.custom_module.router.kernel_DE",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (
            "layers.*.custom_module.router.bias_E",
            (None, "layer", "model"),
        ),
        # Routed Experts (Fused)
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": (
            "layers.*.custom_module.kernel_gating_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": (
            "layers.*.custom_module.kernel_up_proj_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": (
            "layers.*.custom_module.kernel_down_proj_EFD",
            ("expert", "layer", "model", None),
        ),
        # MTP BLOCK (Included for completeness, but typically skipped by current loader)
        "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")),
    }
    return mapping
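To make the MLA hook above concrete, here is a small sketch of the reshape that flatten_3d_to_2d performs; the shape values are illustrative placeholders, not DeepSeek-V3's actual dimensions:

# Sketch: 3D MLA kernel (Rank, Heads, HeadDim) -> 2D (Rank, Heads * HeadDim).
import numpy as np

wkv_b = np.zeros((512, 16, 128))              # made-up (Rank, Heads, HeadDim)
flattened = wkv_b.reshape(wkv_b.shape[0], -1)
print(flattened.shape)                        # (512, 2048) == (Rank, Heads * HeadDim)
# A 2D input would pass through unchanged, matching the `val.ndim == 3` guard.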
src/MaxText/integration/tunix/weight_mapping/gpt_oss.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass
import logging
from typing import Dict, Optional, Tuple
import jax


@dataclass
class GPT_OSS_VLLM_MAPPING:
  """
  Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.

  Supports:
    - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
  """

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_hook_fns():
    """Returns hook functions to fuse interleaved weights."""
    return {}

  @staticmethod
  def to_hf_transpose_keys():
    """Returns keys that need to be transposed."""
    return {}

  @staticmethod
  def to_hf_mapping(
      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
    """Returns the weight mapping for the model.

    Args:
      layer_cycle_interval: The interval at which layers are cycled.
      total_num_layers: The total number of layers in the model.
      interleave_style: The style of interleaving used for the layers.

    Returns:
      A dictionary mapping MaxText parameter names to vLLM parameter names.
    """

    mapping = {}

    # --- 1. Global Parameters ---
    mapping.update(
        {
            "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
            "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
            "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
        }
    )

    # --- 2. Layer Mapping Loop ---
    layers_per_block = total_num_layers // layer_cycle_interval

    for block_idx in range(layer_cycle_interval):
      src_block = f"base.decoder.layers.layers_{block_idx}"
      if interleave_style == "modulo":
        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
      else:
        start = block_idx * layers_per_block
        target_indices = range(start, start + layers_per_block)

      regex_indices = "|".join(map(str, target_indices))
      layer_regex = f"layers\.({regex_indices})"

      # --- 3. Block Mappings (Standard) ---
      mapping.update(
          {
              f"{src_block}.pre_self_attention_layer_norm.scale": (
                  f"{layer_regex}.pre_attention_norm.scale",
                  (None, "layer"),
              ),
              f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
              f"{src_block}.GptOssAttention.query.kernel": (
                  f"{layer_regex}.attn.kernel_q_DNH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.key.kernel": (
                  f"{layer_regex}.attn.kernel_k_DKH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.value.kernel": (
                  f"{layer_regex}.attn.kernel_v_DKH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.out.kernel": (
                  f"{layer_regex}.attn.kernel_o_proj_NHD",
                  ("model", "layer", None, None),
              ),
              f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
              f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
          }
      )

      # MoE Router
      mapping.update(
          {
              f"{src_block}.GptOssMlp.gate.kernel": (
                  f"{layer_regex}.custom_module.router.kernel_DE",
                  (None, "layer", "model"),
              ),
              f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
          }
      )

      # --- MOE EXPERTS ---
      # Separate gate_proj (wi_0) and up_proj (wi_1) kernels and biases.

      # MLP Gate Projection (wi_0)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.gate_proj_kernel", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.gate_proj_bias", ("model", "layer")),
          }
      )

      # MLP Up Projection (wi_1)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wi_1": (f"{layer_regex}.custom_module.up_proj_kernel", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.up_proj_bias", ("model", "layer")),
          }
      )

      # MLP Down Projection (wo)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
          }
      )

    return mapping
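A quick standalone sketch of the modulo interleaving performed by to_hf_mapping(), using small made-up sizes (4 layers, interval 2) instead of the 36-layer default, to show how each scanned block expands into a per-layer regex:

# Sketch only: reproduces the layer_regex construction from the loop above.
layer_cycle_interval = 2
total_num_layers = 4

for block_idx in range(layer_cycle_interval):
  target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
  layer_regex = f"layers\\.({'|'.join(map(str, target_indices))})"
  print(block_idx, layer_regex)

# Output:
#   0 layers\.(0|2)
#   1 layers\.(1|3)

Each scanned MaxText block therefore maps onto every layer whose index is congruent to the block index modulo the cycle interval, which is the "Block 0 -> Layers 0, 2, 4..." behavior described in the class docstring.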
