@@ -16,6 +16,7 @@ The following models are supported:
| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
| **DeepSeek3** | 671B | - | - | √ | - |
| **Qwen3 Next** | 80B | √ | √ | √ | √ |

## Prerequisites

46 changes: 46 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -701,6 +701,51 @@
},
)

qwen3_next_80b_a3b_dict = {
"architectures": [
"Qwen3NextForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"full_attention_interval": 4,
"head_dim": 256,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5120,
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 128,
"linear_num_key_heads": 16,
"linear_num_value_heads": 32,
"linear_value_head_dim": 128,
"max_position_embeddings": 262144,
"mlp_only_layers": [],
"model_type": "qwen3_next",
"moe_intermediate_size": 512,
"norm_topk_prob": true,
"num_attention_heads": 16,
"num_experts": 512,
"num_experts_per_tok": 10,
"num_hidden_layers": 48,
"num_key_value_heads": 2,
"output_router_logits": false,
"partial_rotary_factor": 0.25,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.001,
"shared_expert_intermediate_size": 512,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.57.0.dev0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
mixtral_8x7b_dict = {
@@ -789,6 +834,7 @@
"gpt-oss-20b": gpt_oss_20b_config,
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
"mixtral-8x7b": mixtral_8x7b_config,
"mixtral-8x22b": mixtral_8x22b_config,
}
81 changes: 81 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -349,6 +349,87 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
return mapping


def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
# --- Extract Core Config Values ---
hidden_size = config["hidden_size"]
num_hidden_layers = config["num_hidden_layers"]
vocab_size = config["vocab_size"]
num_attention_heads = config["num_attention_heads"]
num_key_value_heads = config["num_key_value_heads"]
intermediate_size = config["intermediate_size"]
num_experts = config["num_experts"]
head_dim = config["head_dim"]
linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
linear_key_head_dim = config["linear_key_head_dim"]
linear_num_key_heads = config["linear_num_key_heads"]
linear_num_value_heads = config["linear_num_value_heads"]
moe_intermediate_size = config["moe_intermediate_size"]
shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
cycle_interval = config["full_attention_interval"]

# --- Initialize Mapping ---
mapping = {
"model.embed_tokens.weight": [vocab_size, hidden_size],
"model.norm.weight": [hidden_size],
"lm_head.weight": [vocab_size, hidden_size],
}

for layer_idx in range(num_hidden_layers):
layer_prefix = f"model.layers.{layer_idx}"

# Standard Layer Norms
mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]

is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
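# e.g., with full_attention_interval = 4, layers 3, 7, 11, ... use full attention;
# all other layers use the linear-attention (Gated DeltaNet) path.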

if is_full_attention_layer:
# Full Attention Block
Collaborator:
nit: Adding comments explaining how these numbers relate to the config parameters (e.g., hidden_size, num_attention_heads * head_dim, etc.), or whether they are fixed architectural dimensions, would greatly enhance maintainability. For example, it seems 4096 = config["num_attention_heads"] * config["head_dim"].

Collaborator (Author):
Yes, I will add how the hard-coded numbers are calculated. The Gated Delta Net in particular has a bunch of these calculations.

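# Shape notes (cf. the review thread above; values derived from the 80B config, where
# hidden_size=2048, num_attention_heads=16, num_key_value_heads=2, head_dim=256):
#   q_proj: 8192 = num_attention_heads * head_dim * 2 (the factor of 2 appears to cover
#           the fused query + output gate in Qwen3-Next's gated attention)
#   k_proj / v_proj: 512 = num_key_value_heads * head_dim
#   o_proj: input dim 4096 = num_attention_heads * head_dim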
mapping.update({
f"{layer_prefix}.self_attn.q_proj.weight": [8192, hidden_size],
f"{layer_prefix}.self_attn.k_proj.weight": [512, hidden_size],
f"{layer_prefix}.self_attn.v_proj.weight": [512, hidden_size],
f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, 4096],
f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
})
else:
# Linear Attention (GDN) Block
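# Shape notes (best-effort breakdown from the config values above):
#   key_dim   = linear_num_key_heads * linear_key_head_dim     = 16 * 128 = 2048
#   value_dim = linear_num_value_heads * linear_value_head_dim = 32 * 128 = 4096
#   in_proj_qkvz: 12288 = 2 * key_dim + 2 * value_dim (q, k, v, and the z gate)
#   in_proj_ba:      64 = 2 * linear_num_value_heads (b and a, one per value head)
#   conv1d:        8192 = 2 * key_dim + value_dim, kernel size = linear_conv_kernel_dim
#   A_log / dt_bias: 32 = linear_num_value_heads
#   norm:           128 = linear_value_head_dim
#   out_proj: input dim 4096 = value_dim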
mapping.update({
f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [12288, hidden_size],
f"{layer_prefix}.linear_attn.in_proj_ba.weight": [64, hidden_size],
f"{layer_prefix}.linear_attn.conv1d.weight": [8192, 1, 4],
f"{layer_prefix}.linear_attn.A_log": [32],
f"{layer_prefix}.linear_attn.dt_bias": [32],
f"{layer_prefix}.linear_attn.norm.weight": [128],
f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, 4096],
})

# --- MLP Logic (MoE + Shared) ---
mapping.update({
# Router
f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],

# Shared Experts (SwiGLU - Separate Weights)
f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],

# Shared Expert Gate (learned scaling factor)
f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
})

# Routed Experts Loop
# Note: HF typically stores experts as a ModuleList
for e in range(num_experts):
mapping.update({
f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
})

return mapping

def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace GptOss weights path and their shape."""
# --- Extract Core Config Values ---
187 changes: 187 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -792,6 +792,191 @@ def reshape_kernel(input_tensor, target_shape):
return mapping


def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""
Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
All MaxText keys start with 'params-' and use '-' separators; with scan_layers=True,
each key maps to a list of per-layer HuggingFace paths.
"""
num_main_layers = config["num_hidden_layers"]
num_experts = config["num_experts"]
layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

# 1. Non-layer specific weight mappings
mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.norm.weight",
"params-decoder-logits_dense-kernel": "lm_head.weight",
}

if scan_layers:
# 2. Scan over block cycles
for block_idx in range(layer_cycle_interval):
hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
prefix = f"params-decoder-layers-layer_{block_idx}"

# Layer norms
mapping[f"{prefix}-input_layernorm-scale"] = [
f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
]
mapping[f"{prefix}-post_attention_layernorm-scale"] = [
f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
]

# Handle Interleaved Attention (Linear vs Full)
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

if is_full_attention_layer:
mapping.update({
f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
})
else:
# Linear/Hybrid Attention Block
mapping.update({
f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
})

# 3. Handle MLP: Gates and Shared Experts
mapping.update({
f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
})

# 4. Handle MoE Routed Experts
mapping.update({
f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
})
else:
# Unscanned layer mapping
for i in range(num_main_layers):
prefix = f"params-decoder-layers_{i}"

# Layer Norms
mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"

# Determine layer type based on cycle interval
# Assuming block logic: layer i corresponds to block_idx = i % interval
block_idx = i % layer_cycle_interval
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

if is_full_attention_layer:
mapping.update({
f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
})
else:
# Linear/Hybrid Attention Block
mapping.update({
f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
})

# MLP: Gates and Shared Experts
mapping.update({
f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
})

# MoE Routed Experts (List of expert weights for this specific layer)
mapping.update({
f"{prefix}-mlp-routed_experts-wi_0": [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wi_1": [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wo": [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)],
})
return mapping


def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""
Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
"""
def transpose(input_tensor, target_shape=None):
return input_tensor.T

def reshape_and_transpose_attn(input_tensor, target_shape=None):
if saving_to_hf:
emb_dim = input_tensor.shape[0]
return input_tensor.reshape(emb_dim, -1).T
else:
transposed = input_tensor.T
if target_shape is None:
raise ValueError("target_shape required for HF->MaxText attention conversion")
return transposed.reshape(target_shape)

def permute_conv(input_tensor, target_shape=None):
# MT: [K, 1, C] <-> HF: [C, 1, K]
return input_tensor.transpose(2, 1, 0)

# Initialize Hooks
hooks = {
"params-decoder-logits_dense-kernel": transpose,
}

layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
num_experts = config["num_experts"]
num_main_layers = config["num_hidden_layers"]
loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers)

for i in loop_indices:
if scan_layers:
prefix = f"params-decoder-layers-layer_{i}"
block_idx = i
else:
prefix = f"params-decoder-layers_{i}"
block_idx = i % layer_cycle_interval
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

if is_full_attention_layer:
for key in ["query", "key", "value", "out"]:
hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
else:
hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv

mlp_prefix = f"{prefix}-mlp"
hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose

hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose

return hooks


def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.

@@ -1593,6 +1778,7 @@ def scale_query_layer(input_tensor, target_shape):
"gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
"gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
"mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
"mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
}
@@ -1621,6 +1807,7 @@ def scale_query_layer(input_tensor, target_shape):
"gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
"gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
"qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
}
1 change: 1 addition & 0 deletions src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -82,6 +82,7 @@
"gpt-oss-20b": "openai/gpt-oss-20b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
}