diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index 280dbf677d..86fc6b146a 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape): return mapping -def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): - """Returns mapping from MaxText to HuggingFace Qwen3 weight paths. +def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): + """Returns mapping from MaxText to HuggingFace Qwen weight paths. This function generates a dictionary that maps parameter names from a MaxText - Qwen3 checkpoint to their corresponding names in the Hugging Face format. + Qwen checkpoint to their corresponding names in the Hugging Face format. It handles both dense and Mixture-of-Experts (MoE) model variants. Args: @@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) "params-decoder-layers-self_attention-value-kernel": [ f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers) ], + "params-decoder-layers-self_attention-query-bias": [ + f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers) + ], + "params-decoder-layers-self_attention-key-bias": [ + f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers) + ], + "params-decoder-layers-self_attention-value-bias": [ + f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers) + ], "params-decoder-layers-self_attention-out-kernel": [ f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers) ], @@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight", f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight", f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight", + f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias", + f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias", + f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias", f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight", f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight", f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight", @@ -721,8 +733,8 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) return mapping -def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): - """Creates parameter transformation functions for Qwen3. +def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): + """Creates parameter transformation functions for Qwen. This function provides a dictionary of transformation functions (hooks) for converting Qwen3 model parameters between MaxText and Hugging Face formats. @@ -766,6 +778,15 @@ def reshape_kernel(input_tensor, target_shape): else: return input_tensor.T.reshape(target_shape) + def reshape_bias(input_tensor, target_shape=None): + """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden).""" + if saving_to_hf: + # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten) + return input_tensor.reshape(target_shape) + else: + # HF [hidden_dim] -> MaxText [heads, head_dim] + return input_tensor.reshape(target_shape) + mapping = { "params-token_embedder-embedding": pad_embedding_layer, "params-decoder-logits_dense-kernel": reshape_kernel, @@ -780,6 +801,11 @@ def reshape_kernel(input_tensor, target_shape): "mlp-wi_1-kernel", "mlp-wo-kernel", ] + bias_hooks = [ + "self_attention-query-bias", + "self_attention-key-bias", + "self_attention-value-bias", + ] moe_kernel_hooks = [ "moe_block-gate-kernel", "moe_block-wi_0-kernel", @@ -793,6 +819,8 @@ def reshape_kernel(input_tensor, target_shape): if scan_layers: for key in kernel_hooks: mapping[f"params-decoder-layers-{key}"] = reshape_kernel + for key in bias_hooks: + mapping[f"params-decoder-layers-{key}"] = reshape_bias if num_experts > 1: for key in moe_kernel_hooks: mapping[f"params-decoder-layers-{key}"] = reshape_kernel @@ -800,6 +828,8 @@ def reshape_kernel(input_tensor, target_shape): for i in range(n_layers): for key in kernel_hooks: mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel + for key in bias_hooks: + mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias if num_experts > 1: for key in moe_kernel_hooks: mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel @@ -1376,7 +1406,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0) n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"] - text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING( + text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING( config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, maxtext_config=maxtext_config, scan_layers=scan_layers, @@ -1544,7 +1574,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye # Text hooks, reusing QWEN3-MOE hook function num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0) n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"] - text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN( + text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN( config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, maxtext_config=maxtext_config, scan_layers=scan_layers, @@ -2332,24 +2362,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2370,24 +2402,26 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,