Add Qwen3-Next to checkpoint util & update test scripts #2973
base: main
Conversation
```python
is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0

if is_full_attention_layer:
  # Full Attention Block
```
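To make the cycle concrete: with an assumed `cycle_interval` of 4, every fourth layer (0-based indices 3, 7, 11, ...) is a full-attention layer and the rest are linear-attention layers. A minimal sketch, where the interval value is an assumption rather than something read from this PR:

```python
# Illustration only; cycle_interval = 4 is an assumed value, not read from the config.
cycle_interval = 4
layer_kinds = [
    "full" if (layer_idx + 1) % cycle_interval == 0 else "linear"
    for layer_idx in range(8)
]
print(layer_kinds)  # ['linear', 'linear', 'linear', 'full', 'linear', 'linear', 'linear', 'full']
```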
nit: Adding comments explaining how these numbers relate to the config parameters (e.g., hidden_size, num_attention_heads * head_dim, etc.), or whether they are fixed architectural dimensions, would greatly enhance maintainability. For example, it seems 4096 = config["num_attention_heads"] * config["head_dim"].
Yes, I will add how the hard-coded numbers are calculated. The Gated DeltaNet in particular has a bunch of these calculations.
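A sketch of the kind of derivation comment the reviewer is asking for, assuming the `config` dict exposes the usual HF config keys; the concrete values below are assumptions based on the review comment, not read from the PR:

```python
# Derive hard-coded dimensions from the HF config instead of using magic
# numbers. Field names are the usual HF config keys; the values are
# assumptions based on the review comment above.
config = {"num_attention_heads": 16, "head_dim": 256, "hidden_size": 2048}

attn_proj_dim = config["num_attention_heads"] * config["head_dim"]  # 16 * 256 = 4096
assert attn_proj_dim == 4096
```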
parambole left a comment:
I have left a couple of comments. PTAL.
Thanks for adding the model to the conversion tool, along with the careful logit checks! Left a minor comment.
- For future reference, could you also add the conversion commands to the PR description? It would be nice to also include the conversion time in the description, if you have it. Thank you!

For test script 2_test_qwen3_next_80b_a3b.sh:
- Maybe also add pre-training and finetuning (example; see the sketch after this list). Training was omitted from DS3 since it is covered by ubench.
- Could you test this script and attach the log to the description?
- Thanks for updating the description. Maybe update the PR title as well to accurately reflect the change: e.g., add "update test scripts".
- Will this be added to XLML in the other repo?
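For reference, a pre-training smoke test along the lines suggested above might look like the following. This is a sketch only: the flags follow MaxText's usual config-override style, and the model name, paths, and values are assumptions, not taken from this PR.

```sh
# Hypothetical pre-training smoke test; flag names and values are
# assumptions, not copied from the PR's scripts.
python3 -m MaxText.train src/MaxText/configs/base.yml \
  model_name=qwen3-next-80b-a3b \
  load_parameters_path=${CONVERTED_CKPT} \
  base_output_directory=${OUTPUT_DIR} \
  dataset_type=synthetic \
  per_device_batch_size=1 \
  steps=5
```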
Description
[Ckpt Conversion] Support Qwen3-Next in Unified Checkpoint Conversion Utility
This PR migrates Qwen3-Next (`qwen3-next-80b-a3b`) from standalone conversion scripts to the centralized `MaxText.utils.ckpt_conversion` library. Previously, Qwen3-Next relied on ad-hoc scripts for checkpointing; moving it to the unified utility consolidates conversion under the shared tooling.
Changes
- `src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py`: Added `qwen3_next_80b_a3b_config` using `transformers.Qwen3NextConfig` and registered it in `HF_MODEL_CONFIGS`.
- `src/MaxText/utils/ckpt_conversion/utils/param_mapping.py`:
  - `QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING`: Handles the inhomogeneous layer cycle (mapping Full Attention vs. Linear/Hybrid Attention blocks based on layer index) and MoE components (Shared vs. Routed experts).
  - `QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN` with robust tensor handling:
    - Identity hooks for 1D parameters like `A_log` (shape `[1]`). This ensures the scan axis is correctly handled during conversion (e.g., transforming to `[1, 12]` where appropriate) rather than incorrectly collapsing to `[1,]`.
    - `permute_conv` to correctly handle `conv1d` kernels (HF: `[C, 1, K]` <-> MT: `[K, 1, C]`). This prevents dimensions with value `1` from being incorrectly squeezed or flattened during the permutation process.
- `src/MaxText/utils/ckpt_conversion/utils/hf_shape.py`: Added `QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE` to calculate expected HF tensor shapes for validation.
- `end_to_end/tpu/qwen/next/...`:
  - Updated `1_test_qwen3_next_80b_a3b.sh` to use `python3 -m MaxText.utils.ckpt_conversion.to_maxtext` instead of the legacy script.
  - Added `2_test_qwen3_next_80b_a3b.sh` for XLML tests to consume for forward_pass & decode verification.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
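To illustrate the `permute_conv` layout change described above, here is a minimal sketch using NumPy; the actual hook implementation in the PR may differ, and the example sizes are assumptions:

```python
import numpy as np

def permute_conv_hf_to_mt(kernel: np.ndarray) -> np.ndarray:
  """Permute an HF conv1d kernel [C, 1, K] to the MaxText layout [K, 1, C].

  An explicit transpose (rather than a reshape) preserves the middle
  dimension of size 1 instead of squeezing or flattening it away.
  """
  assert kernel.ndim == 3 and kernel.shape[1] == 1
  return np.transpose(kernel, (2, 1, 0))

# Dummy kernel with example (assumed) sizes:
hf_kernel = np.zeros((8192, 1, 4))            # HF layout [C, 1, K]
mt_kernel = permute_conv_hf_to_mt(hf_kernel)  # -> shape (4, 1, 8192)
```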
FIXES: b/469445683
Tests
The commands used to generate the checkpoints themselves: https://paste.googleplex.com/4921565475110912
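Since the paste link above is internal, here is a hedged sketch of what a `to_maxtext` invocation might look like; the flag names follow MaxText's config-override convention and are assumptions, not the PR's actual commands:

```sh
# Hypothetical conversion command; exact flags may differ from the PR.
python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
  model_name=qwen3-next-80b-a3b \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=gs://<bucket>/qwen3-next \
  scan_layers=true
```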
Will run the forward-pass logit checker on checkpoints converted MaxText -> HF -> MaxText (scanned and unscanned) and post the results here:
Current status:
to_maxtext tests:
hf -> maxtext (scanned): https://paste.googleplex.com/5151438898593792
hf -> maxtext (unscanned): https://paste.googleplex.com/4721564912320512
to_huggingface tests:
Convert the scanned & unscanned MaxText checkpoints from the previous tests to HF format, then run the forward_pass check against the new HF checkpoints and the existing MaxText checkpoints (a sketch of this comparison follows the links below).
Maxtext (scanned) -> HF: https://paste.googleplex.com/4787924765900800
Maxtext (unscanned) -> HF: https://paste.googleplex.com/5256341314732032
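The forward_pass checks above compare logits from the converted and reference checkpoints. A minimal sketch of such a comparison, in NumPy only; the tolerance is a placeholder, not the PR's actual threshold:

```python
import numpy as np

def check_logits(ref_logits: np.ndarray, test_logits: np.ndarray, atol: float = 1e-2) -> None:
  """Fail if the converted model's logits drift from the reference logits."""
  max_diff = float(np.max(np.abs(ref_logits - test_logits)))
  assert max_diff <= atol, f"max logit diff {max_diff:.4f} exceeds tolerance {atol}"

# Dummy usage with random logits standing in for real model outputs:
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 8, 32000)).astype(np.float32)
check_logits(logits, logits + 1e-4)
```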
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.