-
Notifications
You must be signed in to change notification settings - Fork 458
Update tflops calc for qwen3 next & GatedDeltaNet #2999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added a unit test in flop_calculation_test.py. Also will run the sanity checks and update the pr description with results. |
parambole
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a comment regarding the attention calculation. PTAL.
parambole
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
RissyRan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change! I sent a quick sync-up, hopefully this could iterate faster.
|
🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
📋 Review Summary
This Pull Request introduces TFLOPs calculation for the Qwen3-Next model architecture, which includes Gated Delta Net (Linear Attention) layers and a Mixture-of-Experts (MoE) structure. The implementation correctly integrates the new calculation logic and includes a comprehensive unit test to verify its accuracy.
🔍 General Feedback
- The TFLOPs calculation logic for Gated Delta Net appears well-detailed and systematically broken down.
- The unit test
test_qwen3_next_flopsprovides good coverage by manually calculating expected FLOPs and comparing them against the implemented function.
6192638 to
8d977db
Compare
Add flop test for qwen3 next and resolve pr comments Typo in model config fixed Update flops_intra to account for inner_attn loop Update testcase with new flops_intra calculation Change if statement to less ambiguous Add comments on QKVZ tensor shapes Typo in comment fix linter issues files formatted
8d977db to
3fcbcc7
Compare
|
Manually adding pull_ready tag since all tests passed and 2 approvals recieved |
Description
PR Title: Enable TFLOPs Calculation for Qwen3-Next
Description
This PR implements the TFLOPs calculation logic for the Qwen3-Next architecture. Qwen3-Next utilizes a hybrid design containing both standard full attention layers and linear attention (Gated Delta Net) layers, alongside a Mixture-of-Experts (MoE) structure. This update ensures training TFLOPs are accurately reported for this model family.
Key Changes
1. TFLOPs Calculation Logic (
src/MaxText/maxtext_utils.py)calculate_gated_delta_net_flops_per_device: Implemented logic to calculate FLOPs for the Linear Attention layers, breaking down operations into:calculate_tflops_training_per_device:DecoderBlockType.QWEN3_NEXT.inhomogeneous_layer_cycle_interval.calculate_routed_and_shared_ffn_tflops_per_deviceandget_dense_moe_layersexplicitly supportQWEN3_NEXT.2. Unit Tests (
tests/unit/flop_calculation_test.py)test_qwen3_next_flops: A new unit test that verifies the implementation against a "golden" manual calculation for the 80B model configuration.compute_qwen3_next_attention_flops_per_device: A helper function to compute the expected attention-specific FLOPs for the test assertion.3. Config Updates (
src/MaxText/configs/models/qwen3-next-80b-a3b.yml)shared_experts: 1to the model configuration to ensure correct parameter counting in the FFN FLOPs calculation.If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/477291633
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.