-
Notifications
You must be signed in to change notification settings - Fork 465
Add Deepseek and GPT OSS weight mapping for RL. #2995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
158 changes: 158 additions & 0 deletions
158
src/MaxText/integration/tunix/weight_mapping/deepseek3.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| # Copyright 2023–2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys.""" | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
|
|
||
| @dataclass | ||
| class DEEPSEEK_VLLM_MAPPING: | ||
| """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys.""" | ||
|
|
||
| @staticmethod | ||
| def to_hf_hook_fns(): | ||
| def flatten_3d_to_2d(val): | ||
| # Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim) | ||
| if val.ndim == 3: | ||
| return val.reshape(val.shape[0], -1) | ||
| return val | ||
|
|
||
| return { | ||
| # MaxText MLA weights are 3D (Rank, Heads, HeadDim). | ||
| # tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them. | ||
| "base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d, | ||
| "base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d, | ||
| "base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d, | ||
| } | ||
|
|
||
| @staticmethod | ||
| def to_hf_transpose_keys(): | ||
| """Returns a list of keys for weights that need to be transposed. | ||
|
|
||
| Returns: | ||
| An empty dictionary, as no keys require transposition for this mapping. | ||
| """ | ||
| return {} | ||
|
|
||
| @staticmethod | ||
| def lora_to_hf_mappings(): | ||
| """Provides the mapping for LoRA (Low-Rank Adaptation) weights. | ||
|
|
||
| Returns: | ||
| None, as LoRA mappings are not defined for this model. | ||
| """ | ||
| return None | ||
|
|
||
| @staticmethod | ||
| def to_hf_mapping(): | ||
| """Returns the weight mapping for the model.""" | ||
| mapping = { | ||
| # --- Base Model Params --- | ||
| # Map to HF names to be safe with loader regexes | ||
| "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)), | ||
| "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)), | ||
| "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")), | ||
| # MLA LAYERS (Map to HF Keys to trigger loader splitting logic) | ||
| # Norms | ||
| "base.decoder.layers.pre_self_attention_layer_norm.scale": ( | ||
| "model.layers.*.input_layernorm.weight", | ||
| (None, "layer"), | ||
| ), | ||
| "base.decoder.layers.post_self_attention_layer_norm.scale": ( | ||
| "model.layers.*.post_attention_layernorm.weight", | ||
| (None, "layer"), | ||
| ), | ||
| # MLA Norms | ||
| "base.decoder.layers.self_attention.kv_norm.scale": ( | ||
| "model.layers.*.self_attn.kv_a_layernorm.weight", | ||
| (None, "layer"), | ||
| ), | ||
| "base.decoder.layers.self_attention.q_norm.scale": ( | ||
| "model.layers.*.self_attn.q_a_layernorm.weight", | ||
| (None, "layer"), | ||
| ), | ||
| # MLA Projections | ||
| # We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj" | ||
| # and performs the necessary split into k_b and v_b for the MLA kernel. | ||
| "base.decoder.layers.self_attention.wq_a.kernel": ( | ||
| "model.layers.*.self_attn.q_a_proj.weight", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| "base.decoder.layers.self_attention.wq_b.kernel": ( | ||
| "model.layers.*.self_attn.q_b_proj.weight", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| "base.decoder.layers.self_attention.wkv_a.kernel": ( | ||
| "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| "base.decoder.layers.self_attention.wkv_b.kernel": ( | ||
| "model.layers.*.self_attn.kv_b_proj.weight", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| "base.decoder.layers.self_attention.out.kernel": ( | ||
| "model.layers.*.self_attn.o_proj.weight", | ||
| ("model", "layer", None, None), | ||
| ), | ||
| # DENSE MLP LAYERS (Map to vllm keys for safety/consistency) | ||
| "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")), | ||
| "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")), | ||
| "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)), | ||
| # MOE LAYERS (Map to INTERNAL keys to bypass loader stacking) | ||
| # Since MaxText experts are already fused/stacked, we map directly to the | ||
| # internal `tpu-inference` param names. The loader will fail to find | ||
| # "experts.{i}" in the name and fall back to loading these directly, | ||
| # which is exactly what we want for performance. | ||
| # Shared Experts | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": ( | ||
| "layers.*.shared_experts.kernel_gating_DF", | ||
| (None, "layer", "model"), | ||
| ), | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": ( | ||
| "layers.*.shared_experts.kernel_up_proj_DF", | ||
| (None, "layer", "model"), | ||
| ), | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": ( | ||
| "layers.*.shared_experts.kernel_down_proj_FD", | ||
| ("model", "layer", None), | ||
| ), | ||
| # Router | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": ( | ||
| "layers.*.custom_module.router.kernel_DE", | ||
| (None, "layer", "model"), | ||
| ), | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": ( | ||
| "layers.*.custom_module.router.bias_E", | ||
| (None, "layer", "model"), | ||
| ), | ||
| # Routed Experts (Fused) | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": ( | ||
| "layers.*.custom_module.kernel_gating_EDF", | ||
| ("expert", "layer", None, "model"), | ||
| ), | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": ( | ||
| "layers.*.custom_module.kernel_up_proj_EDF", | ||
| ("expert", "layer", None, "model"), | ||
| ), | ||
| "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": ( | ||
| "layers.*.custom_module.kernel_down_proj_EFD", | ||
| ("expert", "layer", "model", None), | ||
| ), | ||
| # MTP BLOCK (Included for completeness, but typically skipped by current loader) | ||
| "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)), | ||
| "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)), | ||
| "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")), | ||
| } | ||
| return mapping |
155 changes: 155 additions & 0 deletions
155
src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # Copyright 2023–2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys.""" | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Dict, Optional, Tuple | ||
|
|
||
|
|
||
| @dataclass | ||
| class GPT_OSS_VLLM_MAPPING: | ||
| """ | ||
| Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX. | ||
| Supports: | ||
| - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...) | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def lora_to_hf_mappings(): | ||
| """Provides the mapping for LoRA (Low-Rank Adaptation) weights. | ||
| Returns: | ||
| None, as LoRA mappings are not defined for this model. | ||
| """ | ||
| return None | ||
|
|
||
| @staticmethod | ||
| def to_hf_hook_fns(): | ||
| """Returns hook functions to fuse interleaved weights.""" | ||
| return {} | ||
|
|
||
| @staticmethod | ||
| def to_hf_transpose_keys(): | ||
| """Returns keys that need to be transposed.""" | ||
| return {} | ||
|
|
||
| @staticmethod | ||
| def to_hf_mapping( | ||
| layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo" | ||
| ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]: | ||
| """Returns the weight mapping for the model. | ||
| Args: | ||
| layer_cycle_interval: The interval at which layers are cycled. | ||
| total_num_layers: The total number of layers in the model. | ||
| interleave_style: The style of interleaving used for the layers. | ||
| Returns: | ||
| A dictionary mapping MaxText parameter names to vLLM parameter names. | ||
| """ | ||
|
|
||
| mapping = {} | ||
|
|
||
| # --- 1. Global Parameters --- | ||
| mapping.update( | ||
| { | ||
| "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)), | ||
| "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)), | ||
| "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")), | ||
| } | ||
| ) | ||
|
|
||
| # --- 2. Layer Mapping Loop --- | ||
| layers_per_block = total_num_layers // layer_cycle_interval | ||
|
|
||
| for block_idx in range(layer_cycle_interval): | ||
| src_block = f"base.decoder.layers.layers_{block_idx}" | ||
| if interleave_style == "modulo": | ||
| target_indices = range(block_idx, total_num_layers, layer_cycle_interval) | ||
| else: | ||
| start = block_idx * layers_per_block | ||
| target_indices = range(start, start + layers_per_block) | ||
|
|
||
| regex_indices = "|".join(map(str, target_indices)) | ||
| layer_regex = f"layers\.({regex_indices})" | ||
|
|
||
| # --- 3. Block Mappings (Standard) --- | ||
| mapping.update( | ||
| { | ||
| f"{src_block}.pre_self_attention_layer_norm.scale": ( | ||
| f"{layer_regex}.pre_attention_norm.scale", | ||
| (None, "layer"), | ||
| ), | ||
| f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")), | ||
| f"{src_block}.GptOssAttention.query.kernel": ( | ||
| f"{layer_regex}.attn.kernel_q_DNH", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| f"{src_block}.GptOssAttention.key.kernel": ( | ||
| f"{layer_regex}.attn.kernel_k_DKH", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| f"{src_block}.GptOssAttention.value.kernel": ( | ||
| f"{layer_regex}.attn.kernel_v_DKH", | ||
| (None, "layer", "model", None), | ||
| ), | ||
| f"{src_block}.GptOssAttention.out.kernel": ( | ||
| f"{layer_regex}.attn.kernel_o_proj_NHD", | ||
| ("model", "layer", None, None), | ||
| ), | ||
| f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)), | ||
| f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)), | ||
| f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)), | ||
| f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")), | ||
| f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")), | ||
| } | ||
| ) | ||
|
|
||
| # MoE Router | ||
| mapping.update( | ||
| { | ||
| f"{src_block}.GptOssMlp.gate.kernel": ( | ||
| f"{layer_regex}.custom_module.router.kernel_DE", | ||
| (None, "layer", "model"), | ||
| ), | ||
| f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")), | ||
| } | ||
| ) | ||
|
|
||
| # --- MOE EXPERTS --- | ||
| # Separate gate_proj (wi_0) and up_proj (wi_1) kernels and biases. | ||
|
|
||
| # MLP Gate Projection (wi_0) | ||
| mapping.update( | ||
| { | ||
| f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.gate_proj_kernel", ("model", "layer", None)), | ||
| f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.gate_proj_bias", ("model", "layer")), | ||
| } | ||
| ) | ||
|
|
||
| # MLP Up Projection (wi_1) | ||
| mapping.update( | ||
| { | ||
| f"{src_block}.GptOssMlp.wi_1": (f"{layer_regex}.custom_module.up_proj_kernel", ("model", "layer", None)), | ||
| f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.up_proj_bias", ("model", "layer")), | ||
| } | ||
| ) | ||
|
|
||
| # MLP Down Projection (wo) | ||
| mapping.update( | ||
| { | ||
| f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)), | ||
| f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")), | ||
| } | ||
| ) | ||
|
|
||
| return mapping |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.