7 changes: 6 additions & 1 deletion src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -18,7 +18,8 @@
dispatcher to retrieve the correct weight mapping configuration for a given
model name. This allows for easy extension to support new models.
"""

from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING

@@ -31,6 +32,10 @@ def __getattr__(self, name):
return LLAMA3_VLLM_MAPPING
elif name.startswith("qwen3"):
return QWEN3_VLLM_MAPPING
elif name.startswith("deepseek3"):
return DEEPSEEK_VLLM_MAPPING
elif name.startswith("gpt-oss"):
return GPT_OSS_VLLM_MAPPING
else:
raise ValueError(f"{name} vLLM weight mapping not found.")
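A minimal usage sketch of what the two new branches hand back; the enclosing dispatcher class sits outside this hunk, and the variable names below are illustrative only:

# Sketch only (not part of this diff): any model name starting with "deepseek3"
# resolves to DEEPSEEK_VLLM_MAPPING, any name starting with "gpt-oss" to
# GPT_OSS_VLLM_MAPPING. The returned classes expose the static helpers defined
# in deepseek3.py and gpt_oss.py below.
mapping_cls = DEEPSEEK_VLLM_MAPPING        # what a "deepseek3*" model name resolves to
hf_mapping = mapping_cls.to_hf_mapping()   # MaxText param name -> (target name, per-axis labels)
hook_fns = mapping_cls.to_hf_hook_fns()    # per-key tensor transforms (e.g., the MLA flatten hook)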

158 changes: 158 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/deepseek3.py
@@ -0,0 +1,158 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass


@dataclass
class DEEPSEEK_VLLM_MAPPING:
"""Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys."""

@staticmethod
def to_hf_hook_fns():
def flatten_3d_to_2d(val):
# Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim)
if val.ndim == 3:
return val.reshape(val.shape[0], -1)
return val

return {
# MaxText MLA weights are 3D (Rank, Heads, HeadDim).
# tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them.
"base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d,
"base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d,
"base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d,
}

@staticmethod
def to_hf_transpose_keys():
"""Returns a list of keys for weights that need to be transposed.

Returns:
An empty dictionary, as no keys require transposition for this mapping.
"""
return {}

@staticmethod
def lora_to_hf_mappings():
"""Provides the mapping for LoRA (Low-Rank Adaptation) weights.

Returns:
None, as LoRA mappings are not defined for this model.
"""
return None

@staticmethod
def to_hf_mapping():
"""Returns the weight mapping for the model."""
mapping = {
# --- Base Model Params ---
# Map to HF names to be safe with loader regexes
"base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)),
"base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)),
"base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")),
# MLA LAYERS (Map to HF Keys to trigger loader splitting logic)
# Norms
"base.decoder.layers.pre_self_attention_layer_norm.scale": (
"model.layers.*.input_layernorm.weight",
(None, "layer"),
),
"base.decoder.layers.post_self_attention_layer_norm.scale": (
"model.layers.*.post_attention_layernorm.weight",
(None, "layer"),
),
# MLA Norms
"base.decoder.layers.self_attention.kv_norm.scale": (
"model.layers.*.self_attn.kv_a_layernorm.weight",
(None, "layer"),
),
"base.decoder.layers.self_attention.q_norm.scale": (
"model.layers.*.self_attn.q_a_layernorm.weight",
(None, "layer"),
),
# MLA Projections
# We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj"
# and performs the necessary split into k_b and v_b for the MLA kernel.
"base.decoder.layers.self_attention.wq_a.kernel": (
"model.layers.*.self_attn.q_a_proj.weight",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.wq_b.kernel": (
"model.layers.*.self_attn.q_b_proj.weight",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.wkv_a.kernel": (
"model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.wkv_b.kernel": (
"model.layers.*.self_attn.kv_b_proj.weight",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.out.kernel": (
"model.layers.*.self_attn.o_proj.weight",
("model", "layer", None, None),
),
# DENSE MLP LAYERS (Map to vllm keys for safety/consistency)
"base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")),
"base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")),
"base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)),
# MOE LAYERS (Map to INTERNAL keys to bypass loader stacking)
# Since MaxText experts are already fused/stacked, we map directly to the
# internal `tpu-inference` param names. The loader will fail to find
# "experts.{i}" in the name and fall back to loading these directly,
# which is exactly what we want for performance.
# Shared Experts
"base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (
"layers.*.shared_experts.kernel_gating_DF",
(None, "layer", "model"),
),
"base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (
"layers.*.shared_experts.kernel_up_proj_DF",
(None, "layer", "model"),
),
"base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": (
"layers.*.shared_experts.kernel_down_proj_FD",
("model", "layer", None),
),
# Router
"base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (
"layers.*.custom_module.router.kernel_DE",
(None, "layer", "model"),
),
"base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (
"layers.*.custom_module.router.bias_E",
(None, "layer", "model"),
),
# Routed Experts (Fused)
"base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": (
"layers.*.custom_module.kernel_gating_EDF",
("expert", "layer", None, "model"),
),
"base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": (
"layers.*.custom_module.kernel_up_proj_EDF",
("expert", "layer", None, "model"),
),
"base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": (
"layers.*.custom_module.kernel_down_proj_EFD",
("expert", "layer", "model", None),
),
# MTP BLOCK (Included for completeness, but typically skipped by current loader)
"base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)),
"base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)),
"base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")),
}
return mapping
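A short sketch of how the hook and mapping above compose; the tensor shape is hypothetical and chosen only to illustrate the 3D -> 2D flatten:

import numpy as np

hooks = DEEPSEEK_VLLM_MAPPING.to_hf_hook_fns()
flatten = hooks["base.decoder.layers.self_attention.wkv_b.kernel"]

# Hypothetical (Rank, Heads, HeadDim) kernel; the hook flattens it to
# (Rank, Heads * HeadDim) so the loader can split kv_b_proj into k_b / v_b.
kernel = np.zeros((512, 128, 192))
assert flatten(kernel).shape == (512, 128 * 192)

# The same MaxText key maps to the HF-style name that triggers that split:
target, axes = DEEPSEEK_VLLM_MAPPING.to_hf_mapping()["base.decoder.layers.self_attention.wkv_b.kernel"]
assert target == "model.layers.*.self_attn.kv_b_proj.weight"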
155 changes: 155 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
@@ -0,0 +1,155 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class GPT_OSS_VLLM_MAPPING:
"""
Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
Supports:
- Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
"""

@staticmethod
def lora_to_hf_mappings():
"""Provides the mapping for LoRA (Low-Rank Adaptation) weights.
Returns:
None, as LoRA mappings are not defined for this model.
"""
return None

@staticmethod
def to_hf_hook_fns():
"""Returns hook functions to fuse interleaved weights."""
return {}

@staticmethod
def to_hf_transpose_keys():
"""Returns keys that need to be transposed."""
return {}

@staticmethod
def to_hf_mapping(
layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
"""Returns the weight mapping for the model.
Args:
layer_cycle_interval: The interval at which layers are cycled.
total_num_layers: The total number of layers in the model.
interleave_style: The style of interleaving used for the layers.
Returns:
A dictionary mapping MaxText parameter names to vLLM parameter names.
"""

mapping = {}

# --- 1. Global Parameters ---
mapping.update(
{
"base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
"base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
"base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
}
)

# --- 2. Layer Mapping Loop ---
layers_per_block = total_num_layers // layer_cycle_interval

for block_idx in range(layer_cycle_interval):
src_block = f"base.decoder.layers.layers_{block_idx}"
if interleave_style == "modulo":
target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
else:
start = block_idx * layers_per_block
target_indices = range(start, start + layers_per_block)

regex_indices = "|".join(map(str, target_indices))
layer_regex = f"layers\.({regex_indices})"

# --- 3. Block Mappings (Standard) ---
mapping.update(
{
f"{src_block}.pre_self_attention_layer_norm.scale": (
f"{layer_regex}.pre_attention_norm.scale",
(None, "layer"),
),
f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
f"{src_block}.GptOssAttention.query.kernel": (
f"{layer_regex}.attn.kernel_q_DNH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.key.kernel": (
f"{layer_regex}.attn.kernel_k_DKH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.value.kernel": (
f"{layer_regex}.attn.kernel_v_DKH",
(None, "layer", "model", None),
),
f"{src_block}.GptOssAttention.out.kernel": (
f"{layer_regex}.attn.kernel_o_proj_NHD",
("model", "layer", None, None),
),
f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
}
)

# MoE Router
mapping.update(
{
f"{src_block}.GptOssMlp.gate.kernel": (
f"{layer_regex}.custom_module.router.kernel_DE",
(None, "layer", "model"),
),
f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
}
)

# --- MOE EXPERTS ---
# Separate gate_proj (wi_0) and up_proj (wi_1) kernels and biases.

# MLP Gate Projection (wi_0)
mapping.update(
{
f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.gate_proj_kernel", ("model", "layer", None)),
f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.gate_proj_bias", ("model", "layer")),
}
)

# MLP Up Projection (wi_1)
mapping.update(
{
f"{src_block}.GptOssMlp.wi_1": (f"{layer_regex}.custom_module.up_proj_kernel", ("model", "layer", None)),
f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.up_proj_bias", ("model", "layer")),
}
)

# MLP Down Projection (wo)
mapping.update(
{
f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
}
)

return mapping
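A minimal sketch of the modulo interleaving, using small hypothetical values for layer_cycle_interval and total_num_layers rather than the real model configuration:

mapping = GPT_OSS_VLLM_MAPPING.to_hf_mapping(layer_cycle_interval=2, total_num_layers=6)

# Scanned block 0 feeds layers 0, 2, 4 (block 1 feeds 1, 3, 5), so each scanned
# key expands into a regex over its target layer indices:
src = "base.decoder.layers.layers_0.GptOssAttention.query.kernel"
target, axes = mapping[src]
assert target == r"layers\.(0|2|4).attn.kernel_q_DNH"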