From a6f97aefa4bc989fdedbc4fed1528c717b042e32 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Thu, 29 Jan 2026 18:00:32 +0000
Subject: [PATCH] Add Qwen3-Next to checkpoint util (Squashed)

---
 .../convert_checkpoint.md                     |   1 +
 .../ckpt_conversion/utils/hf_model_configs.py |  46 +++++
 .../utils/ckpt_conversion/utils/hf_shape.py   |  81 ++++++++
 .../ckpt_conversion/utils/param_mapping.py    | 187 ++++++++++++++++++
 .../utils/ckpt_conversion/utils/utils.py      |   1 +
 .../1_test_qwen3_next_80b_a3b.sh              |  76 +++----
 .../2_test_qwen3_next_80b_a3b.sh              |  65 ++++++
 7 files changed, 419 insertions(+), 38 deletions(-)
 create mode 100644 tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh

diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md
index b37d2923c8..db60d58cbb 100644
--- a/docs/guides/checkpointing_solutions/convert_checkpoint.md
+++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -16,6 +16,7 @@ The following models are supported:
 | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
 | **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
 | **DeepSeek3** | 671B | - | - | √ | - |
+| **Qwen3 Next** | 80B | √ | √ | √ | √ |
 
 ## Prerequisites
 
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
index d91b7987ca..176b3da6af 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -701,6 +701,51 @@
     },
 )
 
+qwen3_next_80b_a3b_dict = {
+    "architectures": [
+        "Qwen3NextForCausalLM"
+    ],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
 # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
 mixtral_8x7b_dict = {
@@ -789,6 +834,7 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
     "mixtral-8x7b": mixtral_8x7b_config,
     "mixtral-8x22b": mixtral_8x22b_config,
 }
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
index 081017dd96..59e670a3e1 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -349,6 +349,87 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  intermediate_size = config["intermediate_size"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  linear_value_head_dim = config["linear_value_head_dim"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update({
+          f"{layer_prefix}.self_attn.q_proj.weight": [8192, hidden_size],
+          f"{layer_prefix}.self_attn.k_proj.weight": [512, hidden_size],
+          f"{layer_prefix}.self_attn.v_proj.weight": [512, hidden_size],
+          f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, 4096],
+          f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+          f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+      })
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update({
+          f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [12288, hidden_size],
+          f"{layer_prefix}.linear_attn.in_proj_ba.weight": [64, hidden_size],
+          f"{layer_prefix}.linear_attn.conv1d.weight": [8192, 1, 4],
+          f"{layer_prefix}.linear_attn.A_log": [32],
+          f"{layer_prefix}.linear_attn.dt_bias": [32],
+          f"{layer_prefix}.linear_attn.norm.weight": [128],
+          f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, 4096],
+      })
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update({
+        # Router
+        f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+
+        # Shared Experts (SwiGLU - Separate Weights)
+        f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+
+        # Shared Expert Gate (learned scaling factor)
+        f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+    })
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update({
+          f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+      })
+
+  return mapping
+
+
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path
   and their shape."""
   # --- Extract Core Config Values ---
diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
index a38e77c8f4..3a9b8281a5 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -792,6 +792,191 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping
 
 
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
+  All MaxText keys start with 'params-' and use '-' separators for scanned layers.
+  """
+  num_main_layers = config["num_hidden_layers"]
+  num_experts = config["num_experts"]
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+
+  # 1. Non-layer specific weight mappings
+  mapping = {
+      "params-token_embedder-embedding": "model.embed_tokens.weight",
+      "params-decoder-decoder_norm-scale": "model.norm.weight",
+      "params-decoder-logits_dense-kernel": "lm_head.weight",
+  }
+
+  if scan_layers:
+    # 2. Scan over block cycles
+    for block_idx in range(layer_cycle_interval):
+      hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
+      prefix = f"params-decoder-layers-layer_{block_idx}"
+
+      # Layer norms
+      mapping[f"{prefix}-input_layernorm-scale"] = [
+          f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
+      ]
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = [
+          f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
+      ]
+
+      # Handle Interleaved Attention (Linear vs Full)
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update({
+            f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
+            f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
+        })
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update({
+            f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
+            f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
+            f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
+            f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
+            f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
+            f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
+            f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
+        })
+
+      # 3. Handle MLP: Gates and Shared Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
+          f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
+      })
+
+      # 4. Handle MoE Routed Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
+      })
+  else:
+    # Unscanned layer mapping
+    for i in range(num_main_layers):
+      prefix = f"params-decoder-layers_{i}"
+
+      # Layer Norms
+      mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"
+
+      # Determine layer type based on cycle interval
+      # Assuming block logic: layer i corresponds to block_idx = i % interval
+      block_idx = i % layer_cycle_interval
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update({
+            f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
+            f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
+            f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
+            f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+            f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
+            f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
+        })
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update({
+            f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
+            f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
+            f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
+            f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
+            f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
+            f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
+            f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
+        })
+
+      # MLP: Gates and Shared Experts
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
+          f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+          f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+          f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+          f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
+      })
+
+      # MoE Routed Experts (List of expert weights for this specific layer)
+      mapping.update({
+          f"{prefix}-mlp-routed_experts-wi_0": [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wi_1": [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts)],
+          f"{prefix}-mlp-routed_experts-wo": [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)],
+      })
+  return mapping
+
+
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
+  """
+  def transpose(input_tensor, target_shape=None):
+    return input_tensor.T
+
+  def reshape_and_transpose_attn(input_tensor, target_shape=None):
+    if saving_to_hf:
+      emb_dim = input_tensor.shape[0]
+      return input_tensor.reshape(emb_dim, -1).T
+    else:
+      transposed = input_tensor.T
+      if target_shape is None:
+        raise ValueError("target_shape required for HF->MaxText attention conversion")
+      return transposed.reshape(target_shape)
+
+  def permute_conv(input_tensor, target_shape=None):
+    # MT: [K, 1, C] <-> HF: [C, 1, K]
+    return input_tensor.transpose(2, 1, 0)
+
+  # Initialize Hooks
+  hooks = {
+      "params-decoder-logits_dense-kernel": transpose,
+  }
+
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+  num_experts = config["num_experts"]
+  num_main_layers = config["num_hidden_layers"]
+  loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers)
+
+  for i in loop_indices:
+    if scan_layers:
+      prefix = f"params-decoder-layers-layer_{i}"
+      block_idx = i
+    else:
+      prefix = f"params-decoder-layers_{i}"
+      block_idx = i % layer_cycle_interval
+    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+    if is_full_attention_layer:
+      for key in ["query", "key", "value", "out"]:
+        hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
+    else:
+      hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
+      hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
+      hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
+      hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv
+
+    mlp_prefix = f"{prefix}-mlp"
+    hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose
+
+    hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose
+
+  return hooks
+
+
 def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.
@@ -1593,6 +1778,7 @@ def scale_query_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
 }
@@ -1621,6 +1807,7 @@ def scale_query_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py
index 42eb439539..785deec1de 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/utils.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -82,6 +82,7 @@
     "gpt-oss-20b": "openai/gpt-oss-20b",
     "gpt-oss-120b": "openai/gpt-oss-120b",
     "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+    "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
     "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
 }
diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
index d444fea988..325ebe76af 100644
--- a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
+++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
@@ -1,8 +1,11 @@
 #!/bin/bash
 
-# This script validates a pre-converted MaxText checkpoint against its original
-# HuggingFace counterpart to ensure numerical correctness.
+# This file is documentation for how to get started with Qwen3 Next.
+# This file runs Step 1 on CPU.
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+# Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
 # ---
 # Example Usage:
 #
@@ -17,43 +20,40 @@
 
 set -ex
 
-# --- Configuration & Input Validation ---
+export MODEL_NAME='qwen3-next-80b-a3b'
+export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'
 
-if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then
-  echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set."
-  echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights."
-  exit 1
-fi
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# Set a default for the HF model path if it's not provided by the user
-if [ -z "${HF_MODEL_PATH}" ]; then
-  export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct"
-  echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}"
+# Ensure HF_TOKEN is set
+if [ -z "${HF_TOKEN}" ]; then
+  echo "Error: HF_TOKEN environment variable is not set. Please export your Hugging Face token."
+  echo "Example: export HF_TOKEN=hf_..."
+  exit 1
 fi
 
-# Install dependencies required for the logit checker.
-python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-# --- Run the Forward Pass Logit Checker ---
-
-echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}"
-echo "Against original HF model: ${HF_MODEL_PATH}"
-
-# This command runs the core validation logic.
-JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
-  tokenizer_type=huggingface \
-  tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
-  megablox=False \
-  sparse_matmul=False \
-  load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
-  model_name=qwen3-next-80b-a3b \
-  checkpoint_storage_concurrent_gb=1024 \
-  skip_jax_distributed_system=True \
-  dtype=float32 \
-  weight_dtype=float32 \
-  matmul_precision=highest \
-  --hf_model_path=${HF_MODEL_PATH} \
-  --max_kl_div=0.03 \
-  --run_hf_model=True
-
-echo "Validation complete."
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+    # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+    # this bucket will store all the files generated by MaxText during a run
+    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+    echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
+
+# 1.1 Convert checkpoint to `scanned` format, more suitable for training
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+    model_name=qwen3-next-80b-a3b \
+    base_output_directory=${BASE_OUTPUT_PATH}/scanned \
+    hf_access_token=${HF_TOKEN} \
+    scan_layers=true \
+    use_multimodal=false
+
+# 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+    model_name=qwen3-next-80b-a3b \
+    base_output_directory=${BASE_OUTPUT_PATH}/unscanned \
+    hf_access_token=${HF_TOKEN} \
+    scan_layers=false \
+    use_multimodal=false
\ No newline at end of file
diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh
new file mode 100644
index 0000000000..0c1321d08b
--- /dev/null
+++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# This file is documentation for how to get started with Qwen3 Next.
+
+# This file runs Step 2 on v5p-128 on a daily basis.
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+# Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pretraining, finetuning, and decoding.
+
+# The golden logits can be generated by:
+# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct --output-path=golden_data_qwen3-next-80b-a3b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16
+
+set -ex
+
+export PYTHONPATH=$PYTHONPATH:$(pwd)/src
+
+export MODEL_NAME='qwen3-next-80b-a3b'
+export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'
+
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# e.g., $HOME/maxtext/src/MaxText
+export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+    # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+    # this bucket will store all the files generated by MaxText during a run
+    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+    echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
+
+# Step 2:
+# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands
+# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
+# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
+# Use hard-coded golden checkpoints rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
+SCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/scanned/0/items
+UNSCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/unscanned/0/items
+# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
+export DATASET_PATH=gs://maxtext-dataset
+
+# Test whether the forward pass logits match the golden logits
+# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
+GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
+if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
+    GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
+    GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
+    gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
+fi
+
+python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.1
+
+# Run pre-training - tokamax_gmm implementation
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024
+
+# Run fine-tuning - tokamax_gmm implementation
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024 checkpoint_storage_concurrent_gb=1024
+
+
+# Run decoding - tokamax_gmm implementation
+# Note: decode requires the access token for the huggingface tokenizer even if the model is not gated
+python3 -m MaxText.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
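A quick cross-check of the two new Qwen3-Next mapping tables can catch missing or misspelled tensor paths before running a full conversion. The snippet below is a minimal sketch and not part of the patch: it assumes MaxText is importable from `src/` (e.g. `export PYTHONPATH=$PYTHONPATH:$(pwd)/src`, as in the test scripts), that the installed `transformers` provides `Qwen3NextConfig` so `hf_model_configs` imports cleanly, and it stubs the MaxText config object, since the mapping code only reads `inhomogeneous_layer_cycle_interval` from it.

```python
# Sanity-check sketch (illustrative, not part of the patch): verify that every
# HuggingFace tensor path produced by the MaxText->HF parameter mapping has a
# corresponding entry in the HF shape map.
from types import SimpleNamespace

from MaxText.utils.ckpt_conversion.utils.hf_model_configs import qwen3_next_80b_a3b_dict
from MaxText.utils.ckpt_conversion.utils.hf_shape import QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE
from MaxText.utils.ckpt_conversion.utils.param_mapping import QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING

hf_config = qwen3_next_80b_a3b_dict
# Stub: the real flow builds this from pyconfig; 4 matches `full_attention_interval`.
maxtext_config = SimpleNamespace(inhomogeneous_layer_cycle_interval=4)

hf_shapes = QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(hf_config)
mt_to_hf = QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, maxtext_config, scan_layers=False)

# Flatten the HF targets (unscanned values are either a string or a list of
# per-expert strings) and check that each one has a known shape.
targets = set()
for value in mt_to_hf.values():
  if isinstance(value, str):
    targets.add(value)
  else:
    targets.update(value)

missing = sorted(t for t in targets if t not in hf_shapes)
print(f"{len(targets)} HF tensors referenced, {len(missing)} without a shape entry")
assert not missing, missing[:5]
```

The scanned layout nests the per-expert lists one level deeper, so the flattening step would need to recurse before running the same check on `scan_layers=True`.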