@@ -792,6 +792,191 @@ def reshape_kernel(input_tensor, target_shape):
  return mapping


def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns a mapping from MaxText to HuggingFace Qwen3-Next weight paths.

  All MaxText keys start with 'params-' and use '-' as the path separator. With
  scan_layers=True, each value is a list of HF paths covering every layer in the
  scanned block cycle.
  """
  num_main_layers = config["num_hidden_layers"]
  num_experts = config["num_experts"]
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

  # 1. Non-layer-specific weight mappings
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    # 2. Scan over block cycles
    for block_idx in range(layer_cycle_interval):
      hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
      prefix = f"params-decoder-layers-layer_{block_idx}"

      # Layer norms
      mapping[f"{prefix}-input_layernorm-scale"] = [
          f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
      ]
      mapping[f"{prefix}-post_attention_layernorm-scale"] = [
          f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
      ]

      # Handle interleaved attention (linear vs. full)
      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

      if is_full_attention_layer:
        mapping.update({
            f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
            f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
            f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
            f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
            f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
            f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
        })
      else:
        # Linear/hybrid attention block
        mapping.update({
            f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
            f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
            f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
            f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
            f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
            f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
            f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
        })

      # 3. Handle MLP: gates and shared experts
      mapping.update({
          f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
          f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
          f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
          f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
          f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
      })

      # 4. Handle MoE routed experts
      mapping.update({
          f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
          f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
          f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
      })
  else:
    # Unscanned layer mapping
    for i in range(num_main_layers):
      prefix = f"params-decoder-layers_{i}"

      # Layer norms
      mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
      mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"

      # Determine layer type based on cycle interval.
      # Assuming block logic: layer i corresponds to block_idx = i % interval.
      block_idx = i % layer_cycle_interval
      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

      if is_full_attention_layer:
        mapping.update({
            f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
            f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
            f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
            f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
            f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
            f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
        })
      else:
        # Linear/hybrid attention block
        mapping.update({
            f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
            f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
            f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
            f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
            f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
            f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
            f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
        })

      # MLP: gates and shared experts
      mapping.update({
          f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
          f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
          f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
          f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
          f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
      })

      # MoE routed experts (list of expert weights for this specific layer)
      mapping.update({
          f"{prefix}-mlp-routed_experts-wi_0": [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)],
          f"{prefix}-mlp-routed_experts-wi_1": [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts)],
          f"{prefix}-mlp-routed_experts-wo": [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)],
      })
  return mapping

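
# Illustrative sketch of the resulting keys (hypothetical config values, not taken from a
# real checkpoint): with num_hidden_layers=8, inhomogeneous_layer_cycle_interval=4 and
# scan_layers=True, block 3 is the full-attention block, so the mapping contains e.g.
#   "params-decoder-layers-layer_3-attention-attention-query-kernel" ->
#     ["model.layers.3.self_attn.q_proj.weight", "model.layers.7.self_attn.q_proj.weight"]
# while blocks 0-2 carry the linear-attention ("linear_attn") paths for HF layers
# {0, 4}, {1, 5} and {2, 6} respectively.
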
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Returns transformation hooks for Qwen3-Next parameters, keyed by hyphenated 'params-' MaxText paths."""

  def transpose(input_tensor, target_shape=None):
    return input_tensor.T

  def reshape_and_transpose_attn(input_tensor, target_shape=None):
    if saving_to_hf:
      emb_dim = input_tensor.shape[0]
      return input_tensor.reshape(emb_dim, -1).T
    else:
      transposed = input_tensor.T
      if target_shape is None:
        raise ValueError("target_shape required for HF->MaxText attention conversion")
      return transposed.reshape(target_shape)

  def permute_conv(input_tensor, target_shape=None):
    # MT: [K, 1, C] <-> HF: [C, 1, K]
    return input_tensor.transpose(2, 1, 0)

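  # For example (illustrative shapes only): a MaxText conv1d kernel of shape
  # [kernel_size=4, 1, channels=8192] becomes an HF kernel of shape [8192, 1, 4] under
  # transpose(2, 1, 0); applying the same permutation again restores the MaxText layout.
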
  # Initialize hooks
  hooks = {
      "params-decoder-logits_dense-kernel": transpose,
  }

  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
  num_experts = config["num_experts"]
  num_main_layers = config["num_hidden_layers"]
  loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers)

  for i in loop_indices:
    if scan_layers:
      prefix = f"params-decoder-layers-layer_{i}"
      block_idx = i
    else:
      prefix = f"params-decoder-layers_{i}"
      block_idx = i % layer_cycle_interval
    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

    if is_full_attention_layer:
      for key in ["query", "key", "value", "out"]:
        hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
    else:
      hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
      hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
      hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
      hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv

    mlp_prefix = f"{prefix}-mlp"
    hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
    hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
    hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
    hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
    hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose

    hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
    hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
    hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose

  return hooks

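
# Illustrative sketch of a hook in isolation (hypothetical shapes; the surrounding
# conversion machinery that applies these hooks is unchanged): when saving to HF, a
# MaxText query kernel of shape [emb_dim, num_heads, head_dim] = [2048, 16, 256] is
# flattened to [2048, 4096] and transposed to the HF [4096, 2048] layout; when loading
# from HF, the transpose is undone and the tensor is reshaped back via target_shape.
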
def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.

@@ -1593,6 +1778,7 @@ def scale_query_layer(input_tensor, target_shape):
15931778 "gpt-oss-20b" : GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING ,
15941779 "gpt-oss-120b" : GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING ,
15951780 "qwen3-omni-30b-a3b" : QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING ,
1781+ "qwen3-next-80b-a3b" : QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING ,
15961782 "mixtral-8x7b" : MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING ,
15971783 "mixtral-8x22b" : MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING ,
15981784}
@@ -1621,6 +1807,7 @@ def scale_query_layer(input_tensor, target_shape):
    "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
    "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
}