From aceba0ae87c6e918a165228f7a55c11ef5745980 Mon Sep 17 00:00:00 2001
From: Abhinav Singh
Date: Tue, 13 Jan 2026 19:20:45 +0000
Subject: [PATCH] Add Deepseek mapping for RL.

---
 .../tunix/weight_mapping/__init__.py  |   7 +-
 .../tunix/weight_mapping/deepseek3.py | 158 ++++++++++++++++++
 .../tunix/weight_mapping/gpt_oss.py   | 155 +++++++++++++++++
 3 files changed, 319 insertions(+), 1 deletion(-)
 create mode 100644 src/MaxText/integration/tunix/weight_mapping/deepseek3.py
 create mode 100644 src/MaxText/integration/tunix/weight_mapping/gpt_oss.py

diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py
index d250ee2fe1..7f7a0dc534 100644
--- a/src/MaxText/integration/tunix/weight_mapping/__init__.py
+++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -18,7 +18,8 @@
 dispatcher to retrieve the correct weight mapping configuration for a given
 model name. This allows for easy extension to support new models.
 """
-
+from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 
@@ -31,6 +32,10 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("deepseek3"):
+      return DEEPSEEK_VLLM_MAPPING
+    elif name.startswith("gpt-oss"):
+      return GPT_OSS_VLLM_MAPPING
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
 
diff --git a/src/MaxText/integration/tunix/weight_mapping/deepseek3.py b/src/MaxText/integration/tunix/weight_mapping/deepseek3.py
new file mode 100644
index 0000000000..7f8091f798
--- /dev/null
+++ b/src/MaxText/integration/tunix/weight_mapping/deepseek3.py
@@ -0,0 +1,158 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class DEEPSEEK_VLLM_MAPPING:
+  """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys."""
+
+  @staticmethod
+  def to_hf_hook_fns():
+    def flatten_3d_to_2d(val):
+      # Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim)
+      if val.ndim == 3:
+        return val.reshape(val.shape[0], -1)
+      return val
+
+    return {
+        # MaxText MLA weights are 3D (Rank, Heads, HeadDim).
+        # tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them.
+        "base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d,
+        "base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d,
+        "base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d,
+    }
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns the keys of weights that need to be transposed.
+
+    Returns:
+      An empty dictionary, as no keys require transposition for this mapping.
+    """
+    return {}
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_mapping():
+    """Returns the weight mapping for the model."""
+    mapping = {
+        # --- Base Model Params ---
+        # Map to HF names to be safe with loader regexes
+        "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)),
+        "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)),
+        "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")),
+        # MLA LAYERS (Map to HF Keys to trigger loader splitting logic)
+        # Norms
+        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
+            "model.layers.*.input_layernorm.weight",
+            (None, "layer"),
+        ),
+        "base.decoder.layers.post_self_attention_layer_norm.scale": (
+            "model.layers.*.post_attention_layernorm.weight",
+            (None, "layer"),
+        ),
+        # MLA Norms
+        "base.decoder.layers.self_attention.kv_norm.scale": (
+            "model.layers.*.self_attn.kv_a_layernorm.weight",
+            (None, "layer"),
+        ),
+        "base.decoder.layers.self_attention.q_norm.scale": (
+            "model.layers.*.self_attn.q_a_layernorm.weight",
+            (None, "layer"),
+        ),
+        # MLA Projections
+        # We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj"
+        # and performs the necessary split into k_b and v_b for the MLA kernel.
+        "base.decoder.layers.self_attention.wq_a.kernel": (
+            "model.layers.*.self_attn.q_a_proj.weight",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.wq_b.kernel": (
+            "model.layers.*.self_attn.q_b_proj.weight",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.wkv_a.kernel": (
+            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.wkv_b.kernel": (
+            "model.layers.*.self_attn.kv_b_proj.weight",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.out.kernel": (
+            "model.layers.*.self_attn.o_proj.weight",
+            ("model", "layer", None, None),
+        ),
+        # DENSE MLP LAYERS (Map to HF keys for safety/consistency)
+        "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")),
+        "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")),
+        "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)),
+        # MOE LAYERS (Map to INTERNAL keys to bypass loader stacking)
+        # Since MaxText experts are already fused/stacked, we map directly to the
+        # internal `tpu-inference` param names. The loader will fail to find
+        # "experts.{i}" in the name and fall back to loading these directly,
+        # which is exactly what we want for performance.
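+        # (Illustrative note: a scanned routed-expert kernel such as
+        # `...MoeBlock_0.wi_0` is expected here to be a single stacked array,
+        # roughly (num_experts, num_layers, emb_dim, mlp_dim), matching the
+        # E/D/F suffixes in the "Routed Experts" targets below, rather than
+        # per-expert `experts.{i}` tensors, so no stacking pass is needed.)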
+        # Shared Experts
+        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (
+            "layers.*.shared_experts.kernel_gating_DF",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (
+            "layers.*.shared_experts.kernel_up_proj_DF",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": (
+            "layers.*.shared_experts.kernel_down_proj_FD",
+            ("model", "layer", None),
+        ),
+        # Router
+        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (
+            "layers.*.custom_module.router.kernel_DE",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (
+            "layers.*.custom_module.router.bias_E",
+            (None, "layer", "model"),
+        ),
+        # Routed Experts (Fused)
+        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": (
+            "layers.*.custom_module.kernel_gating_EDF",
+            ("expert", "layer", None, "model"),
+        ),
+        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": (
+            "layers.*.custom_module.kernel_up_proj_EDF",
+            ("expert", "layer", None, "model"),
+        ),
+        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": (
+            "layers.*.custom_module.kernel_down_proj_EFD",
+            ("expert", "layer", "model", None),
+        ),
+        # MTP BLOCK (Included for completeness, but typically skipped by current loader)
+        "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)),
+        "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)),
+        "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")),
+    }
+    return mapping
diff --git a/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
new file mode 100644
index 0000000000..ce004bb4c1
--- /dev/null
+++ b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
@@ -0,0 +1,155 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys."""
+
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+
+@dataclass
+class GPT_OSS_VLLM_MAPPING:
+  """Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
+
+  Supports:
+    - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
+  """
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns hook functions to fuse interleaved weights."""
+    return {}
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns keys that need to be transposed."""
+    return {}
+
+  @staticmethod
+  def to_hf_mapping(
+      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
+  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
+    """Returns the weight mapping for the model.
+
+    Args:
+      layer_cycle_interval: The interval at which layers are cycled.
+      total_num_layers: The total number of layers in the model.
+      interleave_style: The style of interleaving used for the layers.
+
+    Returns:
+      A dictionary mapping MaxText parameter names to vLLM parameter names.
+    """
+
+    mapping = {}
+
+    # --- 1. Global Parameters ---
+    mapping.update(
+        {
+            "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
+            "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
+            "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
+        }
+    )
+
+    # --- 2. Layer Mapping Loop ---
+    layers_per_block = total_num_layers // layer_cycle_interval
+
+    for block_idx in range(layer_cycle_interval):
+      src_block = f"base.decoder.layers.layers_{block_idx}"
+      if interleave_style == "modulo":
+        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
+      else:
+        start = block_idx * layers_per_block
+        target_indices = range(start, start + layers_per_block)
+
+      regex_indices = "|".join(map(str, target_indices))
+      layer_regex = rf"layers\.({regex_indices})"
+
+      # --- 3. Block Mappings (Standard) ---
+      mapping.update(
+          {
+              f"{src_block}.pre_self_attention_layer_norm.scale": (
+                  f"{layer_regex}.pre_attention_norm.scale",
+                  (None, "layer"),
+              ),
+              f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
+              f"{src_block}.GptOssAttention.query.kernel": (
+                  f"{layer_regex}.attn.kernel_q_DNH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.key.kernel": (
+                  f"{layer_regex}.attn.kernel_k_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.value.kernel": (
+                  f"{layer_regex}.attn.kernel_v_DKH",
+                  (None, "layer", "model", None),
+              ),
+              f"{src_block}.GptOssAttention.out.kernel": (
+                  f"{layer_regex}.attn.kernel_o_proj_NHD",
+                  ("model", "layer", None, None),
+              ),
+              f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
+              f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
+              f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
+          }
+      )
+
+      # MoE Router
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.gate.kernel": (
+                  f"{layer_regex}.custom_module.router.kernel_DE",
+                  (None, "layer", "model"),
+              ),
+              f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
+          }
+      )
+
+      # --- MOE EXPERTS ---
+      # Separate gate_proj (wi_0) and up_proj (wi_1) kernels and biases.
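+      # (Illustrative note: the suffix letters in the expert targets below are
+      # read here as E = experts, F = mlp/ffw dim, D = embed dim, so
+      # `mlp2_weight_EFD` is assumed to be the expert-stacked down projection
+      # and `mlp2_bias_ED` its per-expert bias; this reading of the
+      # tpu-inference naming convention is an assumption, not enforced here.)
+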
+      # MLP Gate Projection (wi_0)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.gate_proj_kernel", ("model", "layer", None)),
+              f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.gate_proj_bias", ("model", "layer")),
+          }
+      )
+
+      # MLP Up Projection (wi_1)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wi_1": (f"{layer_regex}.custom_module.up_proj_kernel", ("model", "layer", None)),
+              f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.up_proj_bias", ("model", "layer")),
+          }
+      )
+
+      # MLP Down Projection (wo)
+      mapping.update(
+          {
+              f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
+              f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
+          }
+      )
+
+    return mapping
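
Usage sketch (illustrative, not part of the diff): how the two new mapping classes are meant to be
consumed. It assumes this patch is applied inside a MaxText checkout; model names are matched by
prefix ("deepseek3...", "gpt-oss...") in the __init__.py dispatcher above, but the sketch calls the
mapping classes directly, and the specific lookups and prints are examples only.

  from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
  from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING

  # DeepSeek-V3: a flat dict of MaxText param path -> (target key, sharding axes).
  ds = DEEPSEEK_VLLM_MAPPING.to_hf_mapping()
  print(ds["base.token_embedder.embedding"])
  # ('model.embed_tokens.weight', ('model', None))

  # GPT-OSS: the mapping is generated per scanned block. With the defaults
  # (36 layers, cycle interval 2, modulo interleaving), block 0 covers layers
  # 0, 2, ..., 34, so each of its target keys is a regex alternation over those indices.
  oss = GPT_OSS_VLLM_MAPPING.to_hf_mapping(layer_cycle_interval=2, total_num_layers=36)
  target, sharding = oss["base.decoder.layers.layers_0.GptOssAttention.query.kernel"]
  print(target.startswith(r"layers\.(0|2|4"), sharding)
  # True (None, 'layer', 'model', None)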