@@ -98,7 +98,6 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:
 
     return attention_flops / 1e12  # return tflops
 
-
   def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
     """
     Computes the total training TFLOPs per device for a Qwen3-Next model.
@@ -136,7 +135,6 @@ def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
 
     return (full_attn_flops + linear_attn_flops) / 1e12
 
-
   @pytest.mark.cpu_only
   def test_qwen3_next_flops(self):
     """Test Qwen3-Next Flops calculation"""
@@ -148,22 +146,19 @@ def test_qwen3_next_flops(self):
148146 "decoder_block" : "qwen3_next" ,
149147 "gradient_accumulation_steps" : 1 ,
150148 "skip_jax_distributed_system" : True ,
151-
152149 # Core Architectural Parameters
153150 "base_emb_dim" : 2048 ,
154151 "base_num_decoder_layers" : 48 ,
155152 "base_num_query_heads" : 16 ,
156153 "base_num_kv_heads" : 2 ,
157154 "head_dim" : 256 ,
158155 "vocab_size" : 151936 ,
159-
160156 # MoE Parameters
161- "base_mlp_dim" : 512 , # Note: maxtext_utils uses moe_mlp_dim for calculations
157+ "base_mlp_dim" : 512 , # Note: maxtext_utils uses moe_mlp_dim for calculations
162158 "base_moe_mlp_dim" : 512 ,
163159 "num_experts" : 512 ,
164160 "num_experts_per_tok" : 10 ,
165161 "mlp_activations" : ["silu" , "linear" ],
166-
167162 # Qwen3-Next Specific Parameters
168163 "inhomogeneous_layer_cycle_interval" : 4 ,
169164 "gdn_conv_kernel_dim" : 4 ,
@@ -192,9 +187,9 @@ def test_qwen3_next_flops(self):
     num_experts = kwargs["num_experts"]
     num_routed = kwargs["num_experts_per_tok"]
 
-    params_moe_layer = (emb_dim * num_experts) + \
-        (3 * emb_dim * moe_mlp_dim * 1) + \
-        (3 * emb_dim * moe_mlp_dim * num_routed)
+    params_moe_layer = (
+        (emb_dim * num_experts) + (3 * emb_dim * moe_mlp_dim * 1) + (3 * emb_dim * moe_mlp_dim * num_routed)
+    )
 
     # Full Attention Params (per full layer)
     Hq = kwargs["base_num_query_heads"]
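The reworked params_moe_layer expression counts active MoE parameters per layer: a router matrix (emb_dim x num_experts), one shared expert, and num_experts_per_tok routed experts, where each expert holds three emb_dim x moe_mlp_dim matrices for the ["silu", "linear"] gated MLP. A quick numeric check with the config values above (a sketch, not part of the diff):

emb_dim, moe_mlp_dim = 2048, 512
num_experts, num_routed = 512, 10
router = emb_dim * num_experts                   # 1,048,576 routing weights
shared = 3 * emb_dim * moe_mlp_dim * 1           # 3,145,728 (one shared expert)
routed = 3 * emb_dim * moe_mlp_dim * num_routed  # 31,457,280 (10 active routed experts)
params_moe_layer = router + shared + routed      # 35,651,584 active params per MoE layer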
@@ -214,9 +209,9 @@ def test_qwen3_next_flops(self):
     V_dim = Hv_g * Dv_g
 
     # Projections: qkvz (in->2K+2V), ba (in->2Hv), out (V->in)
-    params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
+    params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
     # Conv: depthwise on 2K+V
-    params_gdn_conv = (2 * K_dim + V_dim) * K_conv
+    params_gdn_conv = (2 * K_dim + V_dim) * K_conv
 
     params_gdn_layer = params_gdn_proj + params_gdn_conv
 
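The gated DeltaNet accounting above is three dense projections plus a depthwise convolution. A standalone sketch of the same arithmetic; the GDN head counts and head dims (K_dim, V_dim, Hv_g) come from config keys not visible in this hunk, so the arguments here are placeholders rather than the test's actual values:

def gdn_layer_params(emb_dim, k_dim, v_dim, hv_g, k_conv):
  """Active parameters in one gated DeltaNet (linear attention) layer."""
  # qkvz projection (in -> 2K+2V), ba projection (in -> 2*Hv), out projection (V -> in)
  proj = emb_dim * (2 * k_dim + 2 * v_dim) + emb_dim * 2 * hv_g + v_dim * emb_dim
  conv = (2 * k_dim + v_dim) * k_conv  # depthwise conv over the 2K+V channels
  return proj + conv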
@@ -225,9 +220,11 @@ def test_qwen3_next_flops(self):
     num_full = N // kwargs["inhomogeneous_layer_cycle_interval"]
     num_linear = N - num_full
 
-    total_active_params = (vocab * emb_dim) + \
-        (num_full * (params_full_attn + params_moe_layer)) + \
-        (num_linear * (params_gdn_layer + params_moe_layer))
+    total_active_params = (
+        (vocab * emb_dim)
+        + (num_full * (params_full_attn + params_moe_layer))
+        + (num_linear * (params_gdn_layer + params_moe_layer))
+    )
 
     # Weight TFLOPs = 6 * B * S * P
     B = kwargs["per_device_batch_size"]
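The 6 * B * S * P comment is the standard dense-training estimate: each active parameter costs about 2 FLOPs (one multiply-add) per token in the forward pass and about 4 in the backward pass. A sketch of the final step under that assumption:

def weight_tflops(batch_size, seq_len, active_params):
  # 6 FLOPs per active parameter per token: 2 forward + 4 backward
  return 6 * batch_size * seq_len * active_params / 1e12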
@@ -245,7 +242,6 @@ def test_qwen3_next_flops(self):
 
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
 
-
   @pytest.mark.cpu_only
   def test_llama2_7b_flops(self):
     """Test Llama2 7b Flops calculation with default parameters"""