1 change: 1 addition & 0 deletions src/MaxText/configs/models/qwen3-next-80b-a3b.yml
@@ -31,6 +31,7 @@ normalization_layer_epsilon: 1.0e-6
base_mlp_dim: 512
base_moe_mlp_dim: 512
num_experts: 512
+ shared_experts: 1
num_experts_per_tok: 10
norm_topk_prob: True

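Qwen3-Next pairs each MoE block's routed experts with a single shared expert; the new shared_experts: 1 entry makes that explicit in the model config, and the FLOPs accounting below includes it.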
92 changes: 90 additions & 2 deletions src/maxtext/utils/maxtext_utils.py
@@ -488,12 +488,82 @@ def get_dense_moe_layers(config):
elif config.decoder_block == DecoderBlockType.LLAMA4:
num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
num_dense_layers = config.num_decoder_layers - num_moe_layers
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
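# Qwen3-Next applies MoE in every decoder layer, so there are no dense MLP layers.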
num_moe_layers = config.num_decoder_layers
num_dense_layers = 0
else:
raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")
raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")

return num_dense_layers, num_moe_layers


def calculate_gated_delta_net_flops_per_device(config):
"""Calculates the FLOPs for a single Gated Delta Net (Linear Attention) layer."""
B = config.per_device_batch_size
S = config.max_target_length
E = config.emb_dim

H_k = config.gdn_num_key_heads
H_v = config.gdn_num_value_heads
D_k = config.gdn_key_head_dim
D_v = config.gdn_value_head_dim
C = config.gdn_chunk_size
K_conv = config.gdn_conv_kernel_dim

K_dim = H_k * D_k
V_dim = H_v * D_v

# 1. Projections (Learnable Weights)
# in_proj_qkvz: E -> 2*K_dim + 2*V_dim
flops_qkvz = 2 * B * S * E * (2 * K_dim + 2 * V_dim)
# in_proj_ba: E -> 2*H_v
flops_ba = 2 * B * S * E * (2 * H_v)
# out_proj: V_dim -> E
flops_out = 2 * B * S * V_dim * E

flops_projections = flops_qkvz + flops_ba + flops_out

# 2. Convolution (Learnable Weights)
# Depthwise conv on dim (2*K_dim + V_dim)
# 2 * B * S * Channels * Kernel
flops_conv = 2 * B * S * (2 * K_dim + V_dim) * K_conv

# 3. Core Gated Delta Net (Attention-like operations)
# Assumptions:
# H = H_v (broadcasting K to V heads if H_v > H_k)
# N = num_chunks, where N * C ≈ S
#
# Query (Q): [B, S, H_v, D_k]
# Keys (K): [B, S, H_v, D_k]
# Values (V): [B, S, H_v, D_v]
# Intra-Chunk Attention (A): [B, N, H_v, C, C]
# Recurrent State (S): [B, N, H_v, D_k, D_v]

# - Intra-chunk terms (per chunk C):
# - attn (K*K): 2 * B * S * H_v * C * D_k
# - val_intra (A*V): 2 * B * S * H_v * C * D_v
# - k_cum (A*K): 2 * B * S * H_v * C * D_k
# - inner_attn_body loop (iterative refinement): ≈ (C - 1) * B * H_v * N * C^2 ≈ B * H_v * S * C^2 (since N * C ≈ S)
flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2)

# - Inter-chunk terms (Recurrent State D_k * D_v):
# - attn_i (Q*K): 2 * B * S * H_v * C * D_k
# - v_prime (K*S): 2 * B * S * H_v * D_k * D_v
# - attn_inter (Q*S): 2 * B * S * H_v * D_k * D_v
# - core_out (A*V): 2 * B * S * H_v * C * D_v
# - update (K*V): 2 * B * S * H_v * D_k * D_v
flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v)

flops_core = flops_intra + flops_inter

# Weights part: Projections + Conv
gdn_weight_flops = flops_projections + flops_conv
# Attention part: Core
gdn_attn_flops = flops_core

return gdn_weight_flops, gdn_attn_flops


def calculate_gemma3_vision_layers_tflops_per_device(config):
"""
Estimate TFLOPs for Gemma3 vision encoder (ViT-style).
@@ -634,7 +704,7 @@ def calculate_tflops_training_per_device(config, log=True):
# MLP flops
if config.num_experts > 1:
# calculation based on dropless implementation
- if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
+ if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT):
total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
else:
gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
@@ -702,6 +772,24 @@ def calculate_tflops_training_per_device(config, log=True):
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
)
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config)
cycle_interval = config.inhomogeneous_layer_cycle_interval
num_full_attn_layers = config.num_decoder_layers // cycle_interval
num_linear_attn_layers = config.num_decoder_layers - num_full_attn_layers
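# e.g., num_decoder_layers=48 with cycle_interval=4 gives 12 full-attention and 36 linear-attention (GDN) layers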

# Weights TFLOPs:
total_weights = (
total_ffn_flops
+ embedding_flops
+ (qkv_flops + projection_flops) * num_full_attn_layers
+ gdn_weight_flops_per_layer * num_linear_attn_layers
)
learnable_weight_tflops = total_weights * 3 / 10**12

# Attention TFLOPs:
total_attn = (causal_attention_flops * num_full_attn_layers) + (gdn_attn_flops_per_layer * num_linear_attn_layers)
attention_tflops = total_attn * 3 / 10**12
else:
# multiply by 3 for both feed forward and back propagation flops
learnable_weight_tflops = (
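As a quick sanity check of the new helper, the formulas can be evaluated with the qwen3-next-80b-a3b values used by the test below (B=1, S=4096). This is a minimal sketch: the SimpleNamespace stand-in for the real config object and the import path are assumptions for illustration only.

from types import SimpleNamespace

# Assumed import path, following the src/maxtext layout in this PR.
from maxtext.utils.maxtext_utils import calculate_gated_delta_net_flops_per_device

cfg = SimpleNamespace(
    per_device_batch_size=1,
    max_target_length=4096,
    emb_dim=2048,
    gdn_num_key_heads=16,
    gdn_num_value_heads=32,
    gdn_key_head_dim=128,
    gdn_value_head_dim=128,
    gdn_chunk_size=64,
    gdn_conv_kernel_dim=4,
)
weight_flops, attn_flops = calculate_gated_delta_net_flops_per_device(cfg)
print(f"GDN weight FLOPs per layer: {weight_flops:.3e}")  # ~2.762e11 (forward only)
print(f"GDN core FLOPs per layer:   {attn_flops:.3e}")    # ~2.416e10 (forward only)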
144 changes: 144 additions & 0 deletions tests/unit/flop_calculation_test.py
@@ -98,6 +98,150 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:

return attention_flops / 1e12 # return tflops

def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
"""
Computes the total training TFLOPs per device for a Qwen3-Next model.
Only counts the attention mechanism operations (non-weights).
"""
B = kwargs["per_device_batch_size"]
S = kwargs["max_target_length"]
N = kwargs["base_num_decoder_layers"]
cycle_interval = kwargs["inhomogeneous_layer_cycle_interval"]

# Layer counts
num_full_layers = N // cycle_interval
num_linear_layers = N - num_full_layers

# 1. Full Attention FLOPs (Causal)
D_head = kwargs["head_dim"]
H_q = kwargs["base_num_query_heads"]
# Two matmuls (QK^T and AV) x 2 FLOPs/MAC = 4 * B * S^2 * H * D,
# halved for causal masking (matching maxtext_utils), x 3 for fwd+bwd.
# Net formula: 2 * 3 * B * S^2 * H * D
full_attn_flops = 2 * 3 * num_full_layers * B * (S**2) * H_q * D_head

# 2. Linear Attention (Gated Delta Net) FLOPs
H_v = kwargs["gdn_num_value_heads"]
D_k = kwargs["gdn_key_head_dim"]
D_v = kwargs["gdn_value_head_dim"]
C = kwargs["gdn_chunk_size"]

# Formulas from maxtext_utils.calculate_gated_delta_net_flops_per_device
flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2)
flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v)

# 3 for fwd+bwd
linear_attn_flops = 3 * num_linear_layers * (flops_intra + flops_inter)

return (full_attn_flops + linear_attn_flops) / 1e12

@pytest.mark.cpu_only
def test_qwen3_next_flops(self):
"""Test Qwen3-Next Flops calculation"""
kwargs = {
"model_name": "qwen3-next-80b-a3b",
"override_model_config": True,
"per_device_batch_size": 1,
"max_target_length": 4096,
"decoder_block": "qwen3_next",
"gradient_accumulation_steps": 1,
"skip_jax_distributed_system": True,
# Core Architectural Parameters
"base_emb_dim": 2048,
"base_num_decoder_layers": 48,
"base_num_query_heads": 16,
"base_num_kv_heads": 2,
"head_dim": 256,
"vocab_size": 151936,
# MoE Parameters
"base_mlp_dim": 512, # Note: maxtext_utils uses moe_mlp_dim for calculations
"base_moe_mlp_dim": 512,
"num_experts": 512,
"num_experts_per_tok": 10,
"mlp_activations": ["silu", "linear"],
# Qwen3-Next Specific Parameters
"inhomogeneous_layer_cycle_interval": 4,
"gdn_conv_kernel_dim": 4,
"gdn_key_head_dim": 128,
"gdn_value_head_dim": 128,
"gdn_num_key_heads": 16,
"gdn_num_value_heads": 32,
"gdn_chunk_size": 64,
}

# 1. Calculate Attention TFLOPs
attention_tflops = self.compute_qwen3_next_attention_flops_per_device(kwargs)

# 2. Calculate Learnable Weight Active Params
# Config Shortcuts
emb_dim = kwargs["base_emb_dim"]
vocab = kwargs["vocab_size"]
N = kwargs["base_num_decoder_layers"]

# MoE Active Params (per layer)
# FFN uses SwiGLU (3 matrices); Qwen3-Next activates 1 shared expert plus num_experts_per_tok routed experts per token
# Params = Gate + Shared + Routed
# Gate: emb_dim * num_experts
# Expert: 3 * emb_dim * moe_mlp_dim
moe_mlp_dim = kwargs["base_moe_mlp_dim"]
num_experts = kwargs["num_experts"]
num_routed = kwargs["num_experts_per_tok"]

params_moe_layer = (
(emb_dim * num_experts) + (3 * emb_dim * moe_mlp_dim * 1) + (3 * emb_dim * moe_mlp_dim * num_routed)
)
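# e.g., with emb_dim=2048, num_experts=512, moe_mlp_dim=512, num_experts_per_tok=10:
#   gate = 2048 * 512 ≈ 1.05M; shared = 3 * 2048 * 512 ≈ 3.15M; routed = 10 * 3.15M ≈ 31.5M
#   => ~35.65M active MoE params per layer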

# Full Attention Params (per full layer)
Hq = kwargs["base_num_query_heads"]
Hkv = kwargs["base_num_kv_heads"]
Hd = kwargs["head_dim"]
# Q, K, V, Out projections
params_full_attn = (emb_dim * (Hq + 2 * Hkv) * Hd) + (Hq * Hd * emb_dim)

# GDN Linear Attention Params (per linear layer)
Hk_g = kwargs["gdn_num_key_heads"]
Hv_g = kwargs["gdn_num_value_heads"]
Dk_g = kwargs["gdn_key_head_dim"]
Dv_g = kwargs["gdn_value_head_dim"]
K_conv = kwargs["gdn_conv_kernel_dim"]

K_dim = Hk_g * Dk_g
V_dim = Hv_g * Dv_g

# Projections: qkvz (in->2K+2V), ba (in->2Hv), out (V->in)
params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
# Conv: depthwise on 2K+V
params_gdn_conv = (2 * K_dim + V_dim) * K_conv

params_gdn_layer = params_gdn_proj + params_gdn_conv

# Total Active Params
# 12 Full Layers, 36 Linear Layers
num_full = N // kwargs["inhomogeneous_layer_cycle_interval"]
num_linear = N - num_full

total_active_params = (
(vocab * emb_dim)
+ (num_full * (params_full_attn + params_moe_layer))
+ (num_linear * (params_gdn_layer + params_moe_layer))
)

# Weight TFLOPs = 6 * B * S * P
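# (6 = 2 FLOPs per multiply-accumulate in the forward pass x 3 to include the backward pass)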
B = kwargs["per_device_batch_size"]
S = kwargs["max_target_length"]
weight_tflops = 6 * B * S * total_active_params / 1e12

golden_tflops = weight_tflops + attention_tflops

# Run Calculation
cfg = pyconfig.initialize(
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
**kwargs,
)
calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)

self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)

@pytest.mark.cpu_only
def test_llama2_7b_flops(self):
"""Test Llama2 7b Flops calculation with default parameters"""
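To run just the new test, one plausible invocation (assuming the repository's standard pytest setup) is:

pytest tests/unit/flop_calculation_test.py -k test_qwen3_next_flops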