
Commit 4fd58f5

Add Qwen3-Next to checkpoint util (Squashed)
1 parent c4499d8 commit 4fd58f5

7 files changed: 419 additions & 38 deletions


docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ The following models are supported:
 | **Mixtral** | 8x7B, 8x22B |||||
 | **GPT-OSS** | 20B, 120B |||||
 | **DeepSeek3** | 671B | - | - || - |
+| **Qwen3 Next** | 80B |||||
 
 ## Prerequisites

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py

Lines changed: 46 additions & 0 deletions
@@ -701,6 +701,51 @@
   },
 )
 
+qwen3_next_80b_a3b_dict = {
+    "architectures": [
+        "Qwen3NextForCausalLM"
+    ],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
 
 # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
 mixtral_8x7b_dict = {
@@ -789,6 +834,7 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
     "mixtral-8x7b": mixtral_8x7b_config,
     "mixtral-8x22b": mixtral_8x22b_config,
 }
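As a quick sanity check, the new entry can be instantiated on its own. A minimal sketch, assuming a transformers build that already ships Qwen3NextConfig (the dict pins transformers_version 4.57.0.dev0) and with the qwen3_next_80b_a3b_dict literal above in scope:

import transformers

qwen3_next_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
# Hybrid layout: every full_attention_interval-th layer (every 4th of 48) uses full attention,
# the others take the linear (GDN) attention path handled in hf_shape.py below.
print(qwen3_next_config.full_attention_interval, qwen3_next_config.num_hidden_layers)  # 4 48
# MoE geometry consumed by the shape and parameter mappings.
print(qwen3_next_config.num_experts, qwen3_next_config.moe_intermediate_size)  # 512 512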

src/MaxText/utils/ckpt_conversion/utils/hf_shape.py

Lines changed: 81 additions & 0 deletions
@@ -349,6 +349,87 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  intermediate_size = config["intermediate_size"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update({
+          f"{layer_prefix}.self_attn.q_proj.weight": [8192, hidden_size],
+          f"{layer_prefix}.self_attn.k_proj.weight": [512, hidden_size],
+          f"{layer_prefix}.self_attn.v_proj.weight": [512, hidden_size],
+          f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, 4096],
+          f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+          f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+      })
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update({
+          f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [12288, hidden_size],
+          f"{layer_prefix}.linear_attn.in_proj_ba.weight": [64, hidden_size],
+          f"{layer_prefix}.linear_attn.conv1d.weight": [8192, 1, 4],
+          f"{layer_prefix}.linear_attn.A_log": [32],
+          f"{layer_prefix}.linear_attn.dt_bias": [32],
+          f"{layer_prefix}.linear_attn.norm.weight": [128],
+          f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, 4096],
+      })
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update({
+        # Router
+        f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+        # Shared Experts (SwiGLU - Separate Weights)
+        f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+        # Shared Expert Gate (learned scaling factor)
+        f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+    })
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update({
+          f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+      })
+
+  return mapping
+
+
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path and their shape."""
   # --- Extract Core Config Values ---
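The new shape map can be checked against the config added above. A small sketch under stated assumptions: the import paths mirror the src/ layout, and the function receives the plain config dict, as the config["..."] accesses imply:

from MaxText.utils.ckpt_conversion.utils.hf_model_configs import qwen3_next_80b_a3b_dict
from MaxText.utils.ckpt_conversion.utils.hf_shape import QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE

shapes = QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(qwen3_next_80b_a3b_dict)
# Layer 3 is the first full-attention layer (full_attention_interval = 4)...
assert "model.layers.3.self_attn.q_proj.weight" in shapes
# ...while layer 0 takes the linear-attention (GDN) branch with the fused qkvz projection.
assert shapes["model.layers.0.linear_attn.in_proj_qkvz.weight"] == [12288, 2048]
# Embedding table and LM head are both [vocab_size, hidden_size].
assert shapes["model.embed_tokens.weight"] == [151936, 2048]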

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 187 additions & 0 deletions
@@ -792,6 +792,191 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping
 
 
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
+  All MaxText keys start with 'params-' and use '-' separators for scanned layers.
+  """
+  num_main_layers = config["num_hidden_layers"]
+  num_experts = config["num_experts"]
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+
+  # 1. Non-layer specific weight mappings
+  mapping = {
+      "params-token_embedder-embedding": "model.embed_tokens.weight",
+      "params-decoder-decoder_norm-scale": "model.norm.weight",
+      "params-decoder-logits_dense-kernel": "lm_head.weight",
+  }
+
+  if scan_layers:
+    # 2. Scan over block cycles
+    for block_idx in range(layer_cycle_interval):
+      hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
+      prefix = f"params-decoder-layers-layer_{block_idx}"
+
+      # Layer norms
+      mapping[f"{prefix}-input_layernorm-scale"] = [
+          f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
+      ]
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = [
+          f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
+      ]
+
+      # Handle Interleaved Attention (Linear vs Full)
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update({
+            f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
+        })
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update({
+            f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
+            f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
+            f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
+            f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
+            f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
+            f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
+            f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
+        })
+
+      # 3. Handle MLP: Gates and Shared Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
+      })
+
+      # 4. Handle MoE Routed Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
+      })
+  else:
+    # Unscanned layer mapping
+    for i in range(num_main_layers):
+      prefix = f"params-decoder-layers_{i}"
+
+      # Layer Norms
+      mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"
+
+      # Determine layer type based on cycle interval
+      # Assuming block logic: layer i corresponds to block_idx = i % interval
+      block_idx = i % layer_cycle_interval
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update({
+            f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
+            f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
+            f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
+            f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+            f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
+            f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
+        })
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update({
+            f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
+            f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
+            f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
+            f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
+            f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
+            f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
+            f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
+        })
+
+      # MLP: Gates and Shared Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
+          f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+          f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+          f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+          f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
+      })
+
+      # MoE Routed Experts (List of expert weights for this specific layer)
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-wi_0": [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wi_1": [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wo": [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)],
+      })
+  return mapping
+
+
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
+  """
+  def transpose(input_tensor, target_shape=None):
+    return input_tensor.T
+
+  def reshape_and_transpose_attn(input_tensor, target_shape=None):
+    if saving_to_hf:
+      emb_dim = input_tensor.shape[0]
+      return input_tensor.reshape(emb_dim, -1).T
+    else:
+      transposed = input_tensor.T
+      if target_shape is None:
+        raise ValueError("target_shape required for HF->MaxText attention conversion")
+      return transposed.reshape(target_shape)
+
+  def permute_conv(input_tensor, target_shape=None):
+    # MT: [K, 1, C] <-> HF: [C, 1, K]
+    return input_tensor.transpose(2, 1, 0)
+
+  # Initialize Hooks
+  hooks = {
+      "params-decoder-logits_dense-kernel": transpose,
+  }
+
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+  num_experts = config["num_experts"]
+  num_main_layers = config["num_hidden_layers"]
+  loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers)
+
+  for i in loop_indices:
+    if scan_layers:
+      prefix = f"params-decoder-layers-layer_{i}"
+      block_idx = i
+    else:
+      prefix = f"params-decoder-layers_{i}"
+      block_idx = i % layer_cycle_interval
+    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+    if is_full_attention_layer:
+      for key in ["query", "key", "value", "out"]:
+        hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
+    else:
+      hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
+      hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
+      hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
+      hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv
+
+    mlp_prefix = f"{prefix}-mlp"
+    hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose
+
+    hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose
+
+  return hooks
+
+
 def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.
 
@@ -1593,6 +1778,7 @@ def scale_query_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
 }
@@ -1621,6 +1807,7 @@ def scale_query_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
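The mapping and hooks can be exercised together on a reduced layer count to see how the interleaving works. A sketch under stated assumptions: the import path mirrors the src/ layout, and maxtext_config is stubbed with SimpleNamespace because only inhomogeneous_layer_cycle_interval is read from it:

from types import SimpleNamespace
import numpy as np

from MaxText.utils.ckpt_conversion.utils.param_mapping import (
    QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
    QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
)

config = {"num_hidden_layers": 8, "num_experts": 4}
mt_config = SimpleNamespace(inhomogeneous_layer_cycle_interval=4)  # stub; only this field is read

mapping = QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=False)
# With a cycle of 4, layer 3 is full attention and layer 0 is linear attention.
assert "params-decoder-layers_3-attention-attention-query-kernel" in mapping
assert mapping["params-decoder-layers_0-attention-in_proj_qkvz-kernel"] == "model.layers.0.linear_attn.in_proj_qkvz.weight"

hooks = QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, mt_config, scan_layers=False, saving_to_hf=True)
# The conv1d hook flips [K, 1, C] (MaxText) to [C, 1, K] (HF).
kernel = np.zeros((4, 1, 8192))
assert hooks["params-decoder-layers_0-attention-conv1d-kernel"](kernel).shape == (8192, 1, 4)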

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@
     "gpt-oss-20b": "openai/gpt-oss-20b",
     "gpt-oss-120b": "openai/gpt-oss-120b",
     "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+    "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
     "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
 }
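With this entry the short model key resolves to a Hub repository id, so standard transformers tooling can fetch the matching metadata. A sketch assuming Hub access and a transformers release that registers the qwen3_next model type:

import transformers

repo_id = "Qwen/Qwen3-Next-80B-A3B-Instruct"  # value mapped from "qwen3-next-80b-a3b" above
hf_config = transformers.AutoConfig.from_pretrained(repo_id)
print(hf_config.model_type)  # expected: "qwen3_next"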
