diff --git a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml b/src/MaxText/configs/models/qwen3-next-80b-a3b.yml
index f48d7da34d..6f362ba4f5 100644
--- a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml
+++ b/src/MaxText/configs/models/qwen3-next-80b-a3b.yml
@@ -31,6 +31,7 @@
 normalization_layer_epsilon: 1.0e-6
 base_mlp_dim: 512
 base_moe_mlp_dim: 512
 num_experts: 512
+shared_experts: 1
 num_experts_per_tok: 10
 norm_topk_prob: True
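The new `shared_experts: 1` entry routes this model through `calculate_routed_and_shared_ffn_tflops_per_device` (used in the utils change below), so each MoE layer is costed as a router gate plus (shared + top-k) SwiGLU experts. As a back-of-envelope sketch with this config's values (a standalone approximation consistent with the golden test at the end of this patch, not the library function itself, whose body is not shown here):

    # Hedged sketch: per-layer MoE FFN FLOPs (forward only) implied by this config.
    B, S, E = 1, 4096, 2048            # per_device_batch_size, max_target_length, emb_dim
    moe_mlp, n_experts = 512, 512      # base_moe_mlp_dim, num_experts
    shared, topk = 1, 10               # shared_experts, num_experts_per_tok

    gate_flops = 2 * B * S * E * n_experts                         # router matmul
    expert_flops = 2 * B * S * 3 * E * moe_mlp * (shared + topk)   # SwiGLU = 3 matmuls per expert
    ffn_flops_per_layer = gate_flops + expert_flops                # ~2.9e11; x3 for fwd+bwd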
diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py
index 89a93f6098..b9df72c1d9 100644
--- a/src/maxtext/utils/maxtext_utils.py
+++ b/src/maxtext/utils/maxtext_utils.py
@@ -488,12 +488,82 @@ def get_dense_moe_layers(config):
   elif config.decoder_block == DecoderBlockType.LLAMA4:
     num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
     num_dense_layers = config.num_decoder_layers - num_moe_layers
+  elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+    num_moe_layers = config.num_decoder_layers
+    num_dense_layers = 0
   else:
-    raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")
+    raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
   return num_dense_layers, num_moe_layers
 
 
+def calculate_gated_delta_net_flops_per_device(config):
+  """Calculates the per-device forward-pass FLOPs for a single Gated Delta Net (Linear Attention) layer."""
+  B = config.per_device_batch_size
+  S = config.max_target_length
+  E = config.emb_dim
+
+  H_k = config.gdn_num_key_heads
+  H_v = config.gdn_num_value_heads
+  D_k = config.gdn_key_head_dim
+  D_v = config.gdn_value_head_dim
+  C = config.gdn_chunk_size
+  K_conv = config.gdn_conv_kernel_dim
+
+  K_dim = H_k * D_k
+  V_dim = H_v * D_v
+
+  # 1. Projections (Learnable Weights)
+  # in_proj_qkvz: E -> 2*K_dim + 2*V_dim
+  flops_qkvz = 2 * B * S * E * (2 * K_dim + 2 * V_dim)
+  # in_proj_ba: E -> 2*H_v
+  flops_ba = 2 * B * S * E * (2 * H_v)
+  # out_proj: V_dim -> E
+  flops_out = 2 * B * S * V_dim * E
+
+  flops_projections = flops_qkvz + flops_ba + flops_out
+
+  # 2. Convolution (Learnable Weights)
+  # Depthwise conv over (2*K_dim + V_dim) channels:
+  # 2 * B * S * Channels * Kernel
+  flops_conv = 2 * B * S * (2 * K_dim + V_dim) * K_conv
+
+  # 3. Core Gated Delta Net (Attention-like operations)
+  # Assumptions:
+  #   H = H_v (K is broadcast from H_k to H_v heads when H_v > H_k)
+  #   N = num_chunks, with N * C ~= S
+  #
+  # Query (Q): [B, S, H_v, D_k]
+  # Keys (K): [B, S, H_v, D_k]
+  # Values (V): [B, S, H_v, D_v]
+  # Intra-Chunk Attention (A): [B, N, H_v, C, C]
+  # Recurrent State (S): [B, N, H_v, D_k, D_v]
+
+  # - Intra-chunk terms (per chunk C):
+  #   - attn (K*K): 2 * B * S * H_v * C * D_k
+  #   - val_intra (A*V): 2 * B * S * H_v * C * D_v
+  #   - k_cum (A*K): 2 * B * S * H_v * C * D_k
+  #   - inner_attn_body loop (iterative refinement): ~= (C - 1) * B * H * N * C^2 ~= B * H * S * C^2
+  flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2)
+
+  # - Inter-chunk terms (Recurrent State D_k * D_v):
+  #   - attn_i (Q*K): 2 * B * S * H_v * C * D_k
+  #   - v_prime (K*S): 2 * B * S * H_v * D_k * D_v
+  #   - attn_inter (Q*S): 2 * B * S * H_v * D_k * D_v
+  #   - core_out (A*V): 2 * B * S * H_v * C * D_v
+  #   - update (K*V): 2 * B * S * H_v * D_k * D_v
+  flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v)
+
+  flops_core = flops_intra + flops_inter
+
+  # Weights part: Projections + Conv
+  gdn_weight_flops = flops_projections + flops_conv
+  # Attention part: Core
+  gdn_attn_flops = flops_core
+
+  return gdn_weight_flops, gdn_attn_flops
+
+
 def calculate_gemma3_vision_layers_tflops_per_device(config):
   """
   Estimate TFLOPs for Gemma3 vision encoder (ViT-style).
@@ -634,7 +704,7 @@ def calculate_tflops_training_per_device(config, log=True):
   # MLP flops
   if config.num_experts > 1:
     # calculation based on dropless implementation
-    if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
+    if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT):
       total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
     else:
       gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
@@ -702,6 +772,24 @@
         (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
     )
     attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
+  elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+    gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config)
+    cycle_interval = config.inhomogeneous_layer_cycle_interval
+    num_full_attn_layers = config.num_decoder_layers // cycle_interval
+    num_linear_attn_layers = config.num_decoder_layers - num_full_attn_layers
+
+    # Weights TFLOPs: FFN + embedding + full-attention projections + GDN projections/conv
+    total_weights = (
+        total_ffn_flops
+        + embedding_flops
+        + (qkv_flops + projection_flops) * num_full_attn_layers
+        + gdn_weight_flops_per_layer * num_linear_attn_layers
+    )
+    learnable_weight_tflops = total_weights * 3 / 10**12
+
+    # Attention TFLOPs: causal attention on full layers + GDN core on linear layers
+    total_attn = (causal_attention_flops * num_full_attn_layers) + (gdn_attn_flops_per_layer * num_linear_attn_layers)
+    attention_tflops = total_attn * 3 / 10**12
   else:
     # multiply by 3 for both feed forward and back propagation flops
     learnable_weight_tflops = (
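Before moving to the test, it helps to see the magnitudes these formulas produce. Plugging the qwen3-next-80b-a3b geometry into the expressions above (a standalone re-derivation with hard-coded values, not a call into MaxText; the rounded totals are approximate):

    # GDN FLOPs per layer (forward only) for B=1, S=4096 with the 80b-a3b geometry.
    B, S, E = 1, 4096, 2048
    H_k, H_v, D_k, D_v, C, K_conv = 16, 32, 128, 128, 64, 4
    K_dim, V_dim = H_k * D_k, H_v * D_v   # 2048, 4096

    flops_proj = 2 * B * S * E * (2 * K_dim + 2 * V_dim) + 2 * B * S * E * (2 * H_v) + 2 * B * S * V_dim * E
    flops_conv = 2 * B * S * (2 * K_dim + V_dim) * K_conv
    flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + B * H_v * S * C**2
    flops_inter = 2 * B * S * H_v * C * (D_k + D_v) + 6 * B * S * H_v * D_k * D_v

    gdn_weight_flops = flops_proj + flops_conv   # ~2.8e11 per layer
    gdn_attn_flops = flops_intra + flops_inter   # ~2.4e10 per layer

At S=4096 the attention-like core is roughly an order of magnitude cheaper than the projections, which is why the linear-attention layers are dominated by their weight FLOPs in the accounting above.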
diff --git a/tests/unit/flop_calculation_test.py b/tests/unit/flop_calculation_test.py
index d147f71ff1..7b0de4317f 100644
--- a/tests/unit/flop_calculation_test.py
+++ b/tests/unit/flop_calculation_test.py
@@ -98,6 +98,150 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:
     return attention_flops / 1e12  # return tflops
 
+  def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
+    """
+    Computes the attention-mechanism training TFLOPs per device for a
+    Qwen3-Next model (learnable-weight FLOPs are excluded).
+    """
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    N = kwargs["base_num_decoder_layers"]
+    cycle_interval = kwargs["inhomogeneous_layer_cycle_interval"]
+
+    # Layer counts
+    num_full_layers = N // cycle_interval
+    num_linear_layers = N - num_full_layers
+
+    # 1. Full Attention FLOPs (Causal)
+    D_head = kwargs["head_dim"]
+    H_q = kwargs["base_num_query_heads"]
+    # 4 * B * S^2 * H * D for QK^T and AV, halved for causal masking
+    # (matching maxtext_utils), times 3 for fwd + bwd.
+    # Net formula: 2 * 3 * B * S^2 * H * D
+    full_attn_flops = 2 * 3 * num_full_layers * B * (S**2) * H_q * D_head
+
+    # 2. Linear Attention (Gated Delta Net) FLOPs
+    H_v = kwargs["gdn_num_value_heads"]
+    D_k = kwargs["gdn_key_head_dim"]
+    D_v = kwargs["gdn_value_head_dim"]
+    C = kwargs["gdn_chunk_size"]
+
+    # Formulas from maxtext_utils.calculate_gated_delta_net_flops_per_device
+    flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2)
+    flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v)
+
+    # 3 for fwd+bwd
+    linear_attn_flops = 3 * num_linear_layers * (flops_intra + flops_inter)
+
+    return (full_attn_flops + linear_attn_flops) / 1e12
+
+  @pytest.mark.cpu_only
+  def test_qwen3_next_flops(self):
+    """Test Qwen3-Next Flops calculation"""
+    kwargs = {
+        "model_name": "qwen3-next-80b-a3b",
+        "override_model_config": True,
+        "per_device_batch_size": 1,
+        "max_target_length": 4096,
+        "decoder_block": "qwen3_next",
+        "gradient_accumulation_steps": 1,
+        "skip_jax_distributed_system": True,
+        # Core Architectural Parameters
+        "base_emb_dim": 2048,
+        "base_num_decoder_layers": 48,
+        "base_num_query_heads": 16,
+        "base_num_kv_heads": 2,
+        "head_dim": 256,
+        "vocab_size": 151936,
+        # MoE Parameters
+        "base_mlp_dim": 512,  # Note: maxtext_utils uses moe_mlp_dim for calculations
+        "base_moe_mlp_dim": 512,
+        "num_experts": 512,
+        "num_experts_per_tok": 10,
+        "mlp_activations": ["silu", "linear"],
+        # Qwen3-Next Specific Parameters
+        "inhomogeneous_layer_cycle_interval": 4,
+        "gdn_conv_kernel_dim": 4,
+        "gdn_key_head_dim": 128,
+        "gdn_value_head_dim": 128,
+        "gdn_num_key_heads": 16,
+        "gdn_num_value_heads": 32,
+        "gdn_chunk_size": 64,
+    }
+
+    # 1. Calculate Attention TFLOPs
+    attention_tflops = self.compute_qwen3_next_attention_flops_per_device(kwargs)
+
+    # 2. Calculate Learnable Weight Active Params
+    # Config Shortcuts
+    emb_dim = kwargs["base_emb_dim"]
+    vocab = kwargs["vocab_size"]
+    N = kwargs["base_num_decoder_layers"]
+
+    # MoE Active Params (per layer)
+    # FFN uses SwiGLU (3 matrices); Qwen3-Next has 1 shared + top-k routed experts.
+    # Params = Gate + Shared + Routed
+    #   Gate: emb_dim * num_experts
+    #   Expert: 3 * emb_dim * moe_mlp_dim
+    moe_mlp_dim = kwargs["base_moe_mlp_dim"]
+    num_experts = kwargs["num_experts"]
+    num_routed = kwargs["num_experts_per_tok"]
+
+    params_moe_layer = (
+        (emb_dim * num_experts) + (3 * emb_dim * moe_mlp_dim * 1) + (3 * emb_dim * moe_mlp_dim * num_routed)
+    )
+
+    # Full Attention Params (per full layer)
+    Hq = kwargs["base_num_query_heads"]
+    Hkv = kwargs["base_num_kv_heads"]
+    Hd = kwargs["head_dim"]
+    # Q, K, V, Out projections
+    params_full_attn = (emb_dim * (Hq + 2 * Hkv) * Hd) + (Hq * Hd * emb_dim)
+
+    # GDN Linear Attention Params (per linear layer)
+    Hk_g = kwargs["gdn_num_key_heads"]
+    Hv_g = kwargs["gdn_num_value_heads"]
+    Dk_g = kwargs["gdn_key_head_dim"]
+    Dv_g = kwargs["gdn_value_head_dim"]
+    K_conv = kwargs["gdn_conv_kernel_dim"]
+
+    K_dim = Hk_g * Dk_g
+    V_dim = Hv_g * Dv_g
+
+    # Projections: qkvz (in->2K+2V), ba (in->2Hv), out (V->in)
+    params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
+    # Conv: depthwise on 2K+V
+    params_gdn_conv = (2 * K_dim + V_dim) * K_conv
+
+    params_gdn_layer = params_gdn_proj + params_gdn_conv
+
+    # Total Active Params
+    # 12 full-attention layers, 36 linear-attention layers (48 layers / cycle interval 4)
+    num_full = N // kwargs["inhomogeneous_layer_cycle_interval"]
+    num_linear = N - num_full
+
+    total_active_params = (
+        (vocab * emb_dim)
+        + (num_full * (params_full_attn + params_moe_layer))
+        + (num_linear * (params_gdn_layer + params_moe_layer))
+    )
+
+    # Weight TFLOPs = 6 * B * S * P
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    weight_tflops = 6 * B * S * total_active_params / 1e12
+
+    golden_tflops = weight_tflops + attention_tflops
+
+    # Run Calculation
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        **kwargs,
+    )
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
   @pytest.mark.cpu_only
   def test_llama2_7b_flops(self):
     """Test Llama2 7b Flops calculation with default parameters"""
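For intuition, the golden value above condenses to a short standalone script (the same formulas with the 80b-a3b values hard-coded; the rounded totals are back-of-envelope figures, not asserted constants):

    # Standalone recomputation of golden_tflops (per device, per step, fwd+bwd).
    B, S, E, vocab = 1, 4096, 2048, 151936
    num_full, num_linear = 48 // 4, 48 - 48 // 4          # 12 full-attention, 36 GDN layers

    params_moe = E * 512 + 3 * E * 512 * (1 + 10)         # gate + shared + 10 routed experts
    params_attn = E * (16 + 2 * 2) * 256 + 16 * 256 * E   # QKV + out projections
    params_gdn = E * (2 * 2048 + 2 * 4096 + 2 * 32) + 4096 * E + (2 * 2048 + 4096) * 4
    params = vocab * E + num_full * (params_attn + params_moe) + num_linear * (params_gdn + params_moe)

    weight_tflops = 6 * B * S * params / 1e12                         # ~85
    full_attn_tflops = 2 * 3 * num_full * B * S**2 * 16 * 256 / 1e12  # ~4.9
    H_v, D_k, D_v, C = 32, 128, 128, 64
    intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + B * H_v * S * C**2
    inter = 2 * B * S * H_v * C * (D_k + D_v) + 6 * B * S * H_v * D_k * D_v
    gdn_tflops = 3 * num_linear * (intra + inter) / 1e12              # ~2.6

    print(weight_tflops + full_attn_tflops + gdn_tflops)              # ~93 TFLOPs/device/step

With roughly 3.5e9 active parameters, the weight term dominates at S=4096; the full-attention and GDN terms together contribute under 10% of the total.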