
Commit 6192638

files formatted
1 parent 2171ec3 commit 6192638

1 file changed: tests/unit/flop_calculation_test.py
Lines changed: 11 additions & 15 deletions

@@ -98,7 +98,6 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:
 
     return attention_flops / 1e12  # return tflops
 
-
   def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
     """
     Computes the total training TFLOPs per device for a Qwen3-Next model.
@@ -136,7 +135,6 @@ def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
 
     return (full_attn_flops + linear_attn_flops) / 1e12
 
-
   @pytest.mark.cpu_only
   def test_qwen3_next_flops(self):
     """Test Qwen3-Next Flops calculation"""
@@ -148,22 +146,19 @@ def test_qwen3_next_flops(self):
         "decoder_block": "qwen3_next",
         "gradient_accumulation_steps": 1,
         "skip_jax_distributed_system": True,
-
         # Core Architectural Parameters
         "base_emb_dim": 2048,
         "base_num_decoder_layers": 48,
         "base_num_query_heads": 16,
         "base_num_kv_heads": 2,
         "head_dim": 256,
         "vocab_size": 151936,
-
         # MoE Parameters
-        "base_mlp_dim": 512, # Note: maxtext_utils uses moe_mlp_dim for calculations
+        "base_mlp_dim": 512,  # Note: maxtext_utils uses moe_mlp_dim for calculations
         "base_moe_mlp_dim": 512,
         "num_experts": 512,
         "num_experts_per_tok": 10,
         "mlp_activations": ["silu", "linear"],
-
         # Qwen3-Next Specific Parameters
         "inhomogeneous_layer_cycle_interval": 4,
         "gdn_conv_kernel_dim": 4,
@@ -192,9 +187,9 @@ def test_qwen3_next_flops(self):
     num_experts = kwargs["num_experts"]
     num_routed = kwargs["num_experts_per_tok"]
 
-    params_moe_layer = (emb_dim * num_experts) + \
-                       (3 * emb_dim * moe_mlp_dim * 1) + \
-                       (3 * emb_dim * moe_mlp_dim * num_routed)
+    params_moe_layer = (
+        (emb_dim * num_experts) + (3 * emb_dim * moe_mlp_dim * 1) + (3 * emb_dim * moe_mlp_dim * num_routed)
+    )
 
     # Full Attention Params (per full layer)
     Hq = kwargs["base_num_query_heads"]
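
Plugging the config values into the reformatted expression gives the per-layer MoE parameter count. The three terms read as the router, one shared expert, and the routed experts (my reading of the "* 1" factor; the test does not label it), and the factor of 3 matches the gate, up, and down projections of a gated-silu MLP:

emb_dim, num_experts, moe_mlp_dim, num_routed = 2048, 512, 512, 10
router = emb_dim * num_experts                   # 1_048_576
shared = 3 * emb_dim * moe_mlp_dim * 1           # 3_145_728
routed = 3 * emb_dim * moe_mlp_dim * num_routed  # 31_457_280
params_moe_layer = router + shared + routed      # 35_651_584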
@@ -214,9 +209,9 @@ def test_qwen3_next_flops(self):
     V_dim = Hv_g * Dv_g
 
     # Projections: qkvz (in->2K+2V), ba (in->2Hv), out (V->in)
-    params_gdn_proj = (emb_dim * (2*K_dim + 2*V_dim)) + (emb_dim * 2*Hv_g) + (V_dim * emb_dim)
+    params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
     # Conv: depthwise on 2K+V
-    params_gdn_conv = (2*K_dim + V_dim) * K_conv
+    params_gdn_conv = (2 * K_dim + V_dim) * K_conv
 
     params_gdn_layer = params_gdn_proj + params_gdn_conv
 
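K_dim, V_dim, Hv_g and K_conv are derived from config entries that fall outside this diff, so here is the same GDN arithmetic as a standalone sketch; the function and its argument names are placeholders, not the test's actual helpers:

def gdn_layer_params(emb_dim, k_dim, v_dim, hv_g, k_conv):
  """Per-GDN-layer parameter count, mirroring the formula above."""
  # qkvz projection (in -> 2K+2V), ba projection (in -> 2Hv), out projection (V -> in)
  proj = (emb_dim * (2 * k_dim + 2 * v_dim)) + (emb_dim * 2 * hv_g) + (v_dim * emb_dim)
  # depthwise conv over the concatenated 2K+V channels, kernel size k_conv
  conv = (2 * k_dim + v_dim) * k_conv
  return proj + conv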
@@ -225,9 +220,11 @@ def test_qwen3_next_flops(self):
     num_full = N // kwargs["inhomogeneous_layer_cycle_interval"]
     num_linear = N - num_full
 
-    total_active_params = (vocab * emb_dim) + \
-                          (num_full * (params_full_attn + params_moe_layer)) + \
-                          (num_linear * (params_gdn_layer + params_moe_layer))
+    total_active_params = (
+        (vocab * emb_dim)
+        + (num_full * (params_full_attn + params_moe_layer))
+        + (num_linear * (params_gdn_layer + params_moe_layer))
+    )
 
     # Weight TFLOPs = 6 * B * S * P
     B = kwargs["per_device_batch_size"]
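
The "# Weight TFLOPs = 6 * B * S * P" comment is the standard dense-training estimate: a forward pass costs roughly 2 FLOPs per active parameter per token and the backward pass roughly 4, hence the factor of 6. A sketch (B and S come from kwargs entries outside this diff, so the example values below are assumptions):

def weight_tflops(batch_size, seq_len, active_params):
  # 2*P forward + 4*P backward FLOPs per token, reported in TFLOPs
  return 6 * batch_size * seq_len * active_params / 1e12

# e.g. with an assumed per-device batch of 1 and sequence length 2048, the
# 35_651_584 params of a single MoE layer alone contribute ~0.44 TFLOPs per step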
@@ -245,7 +242,6 @@ def test_qwen3_next_flops(self):
 
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
 
-
   @pytest.mark.cpu_only
   def test_llama2_7b_flops(self):
     """Test Llama2 7b Flops calculation with default parameters"""
