diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index 819f430983..d86a9839a0 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -226,7 +226,7 @@ def __call__(
     2. K = RoPE(Norm(Wk @ X))
     3. Logits = ReLU(Q @ K.T)  # Pairwise similarity
     4. Head_Weights = (W_proj @ X) * scale  # Dynamic head importance, scale for stability
-    5. Score = Sum_head(Logits * Head_Weights)  # Aggregate heads
+    5. Score = Logits @ Head_Weights  # Aggregate heads
     6. Indices = ArgTopk(Score)

     Args:
@@ -281,7 +281,7 @@ def __call__(
     # Weights scaling affect index_score, but does not affect topk_indices. Keep scaling for numerical stability.
     # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
     weights = weights * (self.n_heads**-0.5) * self.softmax_scale
-    # Weighted sum over head: sum_h(logits * weights)
+    # Aggregate head-wise logits: logits @ weights
     index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)  # [b, t, s]

     # Apply attention mask before TopK
diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py
index 64138db751..9eb98be95d 100644
--- a/src/MaxText/maxtext_utils.py
+++ b/src/MaxText/maxtext_utils.py
@@ -319,11 +319,86 @@ def calculate_llama4_attention_tflops(config):
   return attention_tflops


+def calculate_indexer_mask_ratio(index_topk, max_target_length):
+  """
+  Calculates the sparse-to-dense ratio used to scale attention TFLOPs under the Top-K indexer.
+
+  Each query scores all previous tokens in a causal manner until it hits
+  the Top-K limit.
+
+  Visual representation of the actual sparse pattern (T=8, K=4):
+        Key (S) ->
+  Q1 [X . . . . . . .]  <- 1 token scored
+  Q2 [X X . . . . . .]  <- 2 tokens scored
+  Q3 [X X X . . . . .]  <- 3 tokens scored
+  Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
+  Q5 [X X X . X . . .]  <- 4 tokens scored
+  Q6 [X X . X . X . .]  <- 4 tokens scored
+  Q7 [X . X X . . X .]  <- 4 tokens scored
+  Q8 [X X . X . . . X]  <- 4 tokens scored
+
+  Equivalent pattern used for the MFU calculation (same per-row counts, T=8, K=4):
+        Key (S) ->
+  Q1 [X . . . . . . .]  <- 1 token scored
+  Q2 [X X . . . . . .]  <- 2 tokens scored
+  Q3 [X X X . . . . .]  <- 3 tokens scored
+  Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
+  Q5 [X X X X . . . .]  <- 4 tokens scored
+  Q6 [X X X X . . . .]  <- 4 tokens scored
+  Q7 [X X X X . . . .]  <- 4 tokens scored
+  Q8 [X X X X . . . .]  <- 4 tokens scored
+
+  Mathematical calculation:
+  - Triangle (Phase 1: queries 1 to K): K^2 / 2
+  - Rectangle (Phase 2: queries K+1 to T): (T - K) * K
+  - Total active area = T*K - K^2 / 2
+  - Dense area = T^2
+
+  Ratio = (T*K - 0.5*K^2) / T^2 = (K/T) - 0.5*(K/T)^2
+  """
+
+  T = float(max_target_length)
+  K = float(index_topk)
+
+  ratio = K / T
+  mask_multiplier = ratio - (0.5 * ratio**2)
+  return mask_multiplier
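Reviewer note (not part of the patch): a quick numeric check of the closed form above. The exact per-row count is min(i, K), so the K^2/2 triangle term slightly undershoots the true K*(K+1)/2 at tiny T, but the two agree closely at production lengths. A minimal sketch:

def exact_indexer_ratio(index_topk, max_target_length):
  # Row i (1-indexed) scores min(i, K) tokens in the MFU diagram above.
  return sum(min(i, index_topk) for i in range(1, max_target_length + 1)) / max_target_length**2

# T=8,    K=4:    closed form = 0.375, exact = 26/64 = 0.40625
# T=4096, K=2048: closed form = 0.375, exact ~= 0.37506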
+
+
+def calculate_indexer_tflops_per_device(config):
+  """Calculates TFLOPs for the DeepSeek Lightning Indexer (handles causal reduction)."""
+  batch_len = config.per_device_batch_size * config.max_target_length
+
+  # 1. Calculate projection flops
+  # Query: [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
+  q_flops = 2 * batch_len * config.q_lora_rank * config.index_n_heads * config.index_head_dim
+  # Key: [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
+  k_flops = 2 * batch_len * config.emb_dim * config.index_head_dim
+  # Head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
+  head_weight_flops = 2 * batch_len * config.emb_dim * config.index_n_heads
+  proj_flops = q_flops + k_flops + head_weight_flops
+
+  # 2. Calculate index score flops
+  # QK product: [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
+  # --> [batch, seq, seq, index_n_heads]
+  qk_product_flops = 2 * batch_len * config.max_target_length * config.index_n_heads * config.index_head_dim
+  # Aggregate heads: [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
+  head_reduction_flops = 2 * batch_len * config.max_target_length * config.index_n_heads
+  # Apply causal mask: divide by 2 to account for triangular interactions.
+  # The mask restricts the indexer's search space prior to Top-K filtering.
+  scoring_flops = (qk_product_flops + head_reduction_flops) / 2
+
+  return proj_flops, scoring_flops
+
+
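Reviewer note (not part of the patch): a rough per-layer, per-device magnitude check of the new helper, plugging in the DeepSeek-V3.2 dimensions used by the unit test added below (per_device_batch_size=4, max_target_length=4096, q_lora_rank=1536, emb_dim=7168, index_n_heads=64, index_head_dim=128):

batch_len = 4 * 4096
proj = (2 * batch_len * 1536 * 64 * 128     # query projection, ~4.1e11
        + 2 * batch_len * 7168 * 128        # key projection, ~3.0e10
        + 2 * batch_len * 7168 * 64)        # head weights, ~1.5e10
scoring = (2 * batch_len * 4096 * 64 * 128  # QK product, ~1.1e12
           + 2 * batch_len * 4096 * 64) / 2 # head reduction, then causal /2
print(proj / 1e12, scoring / 1e12)          # ~0.46 and ~0.55 TFLOPs per layer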
 def calculate_mla_tflops_per_device(config):
-  """Calculate Multi-Head Latent Attention TFLOP"""
+  """Calculate Multi-Head Latent Attention TFLOPs (handles causal reduction)"""
   batch_len = config.per_device_batch_size * config.max_target_length
   qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
-  # calculate mla query projection
+
+  # 1. calculate mla query projection
   if config.q_lora_rank == 0:
     q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
   else:
@@ -333,7 +408,8 @@ def calculate_mla_tflops_per_device(config):
         * batch_len
         * (config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum)
     )
-  # calculate mla kv projection with down and up flops
+
+  # 2. calculate mla kv projection
   kv_flops = (
       2
       * batch_len
@@ -344,9 +420,31 @@ def calculate_mla_tflops_per_device(config):
   )
   qkv_flops = q_flops + kv_flops

-  attention_flops = (
-      2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
-  )
+  # 3. calculate attention
+  if config.use_sparse_indexer and config.max_target_length > config.index_topk:
+    # get indexer flops
+    indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
+    qkv_flops += indexer_proj_flops
+
+    # calculate the proportion of the T x T causal matrix that the indexer actually explores;
+    # this follows the area (T*K - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
+    multiplier = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
+    attention_flops = (
+        2
+        * batch_len
+        * config.max_target_length
+        * config.num_query_heads
+        * (qk_head_dim_sum + config.v_head_dim)
+        * multiplier
+    )
+    attention_flops += indexer_scoring_flops
+  else:
+    # Standard MLA, or sparse indexer with max_target_length <= index_topk: in both cases the
+    # indexer is bypassed and the causal mask already bounds the attended region, so halve the dense cost.
+    attention_flops = (
+        2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
+    )
+    attention_flops = attention_flops / 2

   projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
   return qkv_flops, attention_flops, projection_flops
@@ -546,7 +644,7 @@ def calculate_tflops_training_per_device(config, log=True):

   # Attention flops
   if config.attention_type == "mla":
-    qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
+    qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
   else:
     qkv_flops = (
         2
@@ -568,11 +666,11 @@ def calculate_tflops_training_per_device(config, log=True):
         * config.head_dim
     )

-  # Divide attention flops by 2 due to causal mask
-  # References:
-  # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
-  # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
-  causal_attention_flops = noncausal_attention_flops / 2
+    # Divide attention flops by 2 due to causal mask
+    # References:
+    # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
+    # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
+    causal_attention_flops = noncausal_attention_flops / 2

   # Embedding flops
   embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
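Reviewer note (not part of the patch): how the new branch changes the headline attention term for the test configuration below (max_target_length=4096, index_topk=2048). The sparse multiplier replaces the flat causal 1/2, so the core QK/PV matmuls are counted at 75% of the dense-causal figure, while the indexer projection and scoring FLOPs are added on top:

T, K = 4096, 2048
sparse_multiplier = (K / T) - 0.5 * (K / T) ** 2  # 0.375
dense_causal = 0.5
print(sparse_multiplier / dense_causal)           # 0.75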
diff --git a/tests/unit/flop_calculation_test.py b/tests/unit/flop_calculation_test.py
index 5c2ff253e0..76069ed74b 100644
--- a/tests/unit/flop_calculation_test.py
+++ b/tests/unit/flop_calculation_test.py
@@ -292,3 +292,55 @@ def test_gpt_oss_20b_flops(self):
     )
     calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
+  @pytest.mark.cpu_only
+  def test_deepseek32_671b_flops(self):
+    """Test DeepSeek3.2-671b FLOPs calculation"""
+    kwargs = {
+        # Model bases
+        "model_name": "deepseek3.2-671b",
+        "override_model_config": True,
+        # Core workload parameters
+        "per_device_batch_size": 4,
+        "max_target_length": 4096,
+        "num_experts": 256,
+        "num_experts_per_tok": 8,
+        "shared_experts": 1,
+        # Model dimensions
+        "base_emb_dim": 7168,
+        "base_num_query_heads": 128,
+        "base_num_kv_heads": 128,
+        "base_mlp_dim": 18432,
+        "base_moe_mlp_dim": 2048,
+        "base_num_decoder_layers": 61,
+        "first_num_dense_layers": 3,
+        "mlp_activations": ["silu", "linear"],
+        "vocab_size": 129280,
+        # MLA
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "skip_jax_distributed_system": True,
+        # Indexer for DeepSeek Sparse Attention
+        "use_sparse_indexer": True,
+        "index_n_heads": 64,
+        "index_head_dim": 128,
+        "index_topk": 2048,
+        # TODO(ranran): remove after flash attention is supported
+        "attention": "dot_product",
+    }
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs)
+    # deepseek3-671b has ~37B active parameters
+    # https://arxiv.org/pdf/2412.19437
+    golden_param_size = 37e9
+    golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        **kwargs,
+    )
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
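Reviewer note (not part of the patch): unpacking the golden value, assuming the test's B=4, S=4096 and the ~37B active-parameter figure cited from the DeepSeek-V3 report. The 6*N*D rule of thumb gives the dense-compute term; the attention term from the pre-existing compute_deepseek_attention_flops_per_device helper (not shown in this diff) is added on top:

B, S, active_params = 4, 4096, 37e9
print(6 * B * S * active_params / 1e12)  # ~3637 dense TFLOPs per device, before the attention term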