4 changes: 2 additions & 2 deletions src/MaxText/layers/attention_mla.py
@@ -226,7 +226,7 @@ def __call__(
2. K = RoPE(Norm(Wk @ X))
3. Logits = ReLU(Q @ K.T) # Pairwise similarity
4. Head_Weights = (W_proj @ X) * scale # Dynamic head importance, scale for stability
5. Score = Sum_head(Logits * Head_Weights) # Aggregate heads
5. Score = Logits @ Head_Weights # Aggregate heads
6. Indices = ArgTopk(Score)

Args:
@@ -281,7 +281,7 @@ def __call__(
# Weights scaling affects index_score, but does not affect topk_indices. Keep scaling for numerical stability.
# https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
weights = weights * (self.n_heads**-0.5) * self.softmax_scale
# Weighted sum over head: sum_h(logits * weights)
# Aggregate head-wise logits: logits @ weights
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]

# Apply attention mask before TopK
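
For readers skimming the diff, here is a minimal JAX sketch of the scoring path summarized in steps 1-6 above; the weight names, shapes, and the Top-K call are illustrative assumptions (Norm, RoPE, and masking are omitted), not the module's actual parameters:

import jax
import jax.numpy as jnp

def indexer_score_sketch(x, wq, wk, w_proj, topk):
  # x: [b, t, d]; wq: [d, h, e]; wk: [d, e]; w_proj: [d, h]  (illustrative shapes)
  q = jnp.einsum("btd,dhe->bthe", x, wq)                      # 1. per-head queries
  k = jnp.einsum("btd,de->bte", x, wk)                        # 2. shared indexer key
  logits = jax.nn.relu(jnp.einsum("bthe,bse->btsh", q, k))    # 3. pairwise similarity
  weights = jnp.einsum("btd,dh->bth", x, w_proj)              # 4. dynamic head importance
  weights = weights * (wq.shape[1] ** -0.5)                   #    scale for stability
  index_score = jnp.einsum("btsh,bth->bts", logits, weights)  # 5. aggregate heads
  _, topk_indices = jax.lax.top_k(index_score, topk)          # 6. Top-K keys per query
  return topk_indices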
122 changes: 110 additions & 12 deletions src/MaxText/maxtext_utils.py
@@ -319,11 +319,86 @@ def calculate_llama4_attention_tflops(config):
return attention_tflops


def calculate_indexer_mask_ratio(index_topk, max_target_length):
"""
Calculates the sparse-to-dense FLOPs ratio induced by the Top-K indexer.

Each query scores all previous tokens in a causal manner until it hits the
Top-K limit; beyond that point, exactly `index_topk` tokens are scored per query.

Visual Representation (T=8, K=4):
Key (S) ->
Q1 [X . . . . . . .] <- 1 token scored
Q2 [X X . . . . . .] <- 2 tokens scored
Q3 [X X X . . . . .] <- 3 tokens scored
Q4 [X X X X . . . .] <- 4 tokens scored (K limit reached)
Q5 [X X X . X . . .] <- 4 tokens scored
Q6 [X X . X . X . .] <- 4 tokens scored
Q7 [X . X X . . X .] <- 4 tokens scored
Q8 [X X . X . . . X] <- 4 tokens scored

For MFU calculation, only the number of tokens scored per query matters, so the selected positions can be treated as a contiguous block:

Visual Representation (T=8, K=4):
Key (S) ->
Q1 [X . . . . . . .] <- 1 token scored
Q2 [X X . . . . . .] <- 2 tokens scored
Q3 [X X X . . . . .] <- 3 tokens scored
Q4 [X X X X . . . .] <- 4 tokens scored (K limit reached)
Q5 [X X X X . . . .] <- 4 tokens scored
Q6 [X X X X . . . .] <- 4 tokens scored
Q7 [X X X X . . . .] <- 4 tokens scored
Q8 [X X X X . . . .] <- 4 tokens scored

Mathematical Calculation:
- Triangle (Phase 1: queries 1 to K): ~K^2 / 2 (continuous approximation of 1 + 2 + ... + K = K(K+1)/2)
- Rectangle (Phase 2: K+1 to T): (T - K) * K
- Total Active Area = TK - K^2 / 2
- Dense Area = T^2

Ratio = (TK - 0.5*K^2) / T^2 => (K/T) - 0.5*(K/T)^2
"""

T = float(max_target_length)
K = float(index_topk)

ratio = K / T
mask_multiplier = ratio - (0.5 * ratio**2)
return mask_multiplier
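
A quick numeric check of the formula (illustrative only, using the test configuration further below rather than anything computed by the diff): with T = 4096 and K = 2048, K/T = 0.5 and the ratio comes to 0.5 - 0.5 * 0.25 = 0.375, i.e. roughly 37.5% of the dense T x T score matrix is active.

# Illustrative check with assumed test values; not part of the diff.
assert calculate_indexer_mask_ratio(index_topk=2048, max_target_length=4096) == 0.375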


def calculate_indexer_tflops_per_device(config):
"""Calculates TFLOPs for the DeepSeek Lightning Indexer (handles causal reduction)."""
batch_len = config.per_device_batch_size * config.max_target_length

# 1. Calculate projections flops
# Query: [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
q_flops = 2 * batch_len * config.q_lora_rank * config.index_n_heads * config.index_head_dim
# Key: [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
k_flops = 2 * batch_len * config.emb_dim * config.index_head_dim
# Head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
head_weight_flops = 2 * batch_len * config.emb_dim * config.index_n_heads
proj_flops = q_flops + k_flops + head_weight_flops

# 2. Calculate index score flops
# QK product [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
# --> [batch, seq, seq, index_n_heads]
qk_product_flops = 2 * batch_len * config.max_target_length * config.index_n_heads * config.index_head_dim
# Aggregate heads [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
head_reduction_flops = 2 * batch_len * config.max_target_length * config.index_n_heads
# Apply causal mask: Divide by 2 to account for triangular interactions
# The mask restricts the indexer's search space prior to Top-K filtering
scoring_flops = (qk_product_flops + head_reduction_flops) / 2

return proj_flops, scoring_flops
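
As a rough magnitude check with the DeepSeek-V3.2 test configuration further below (per_device_batch_size = 4, max_target_length = 4096, emb_dim = 7168, q_lora_rank = 1536, 64 indexer heads of dim 128), and these are illustrative numbers, not outputs of the diff itself: the projection term comes to roughly 4.6e11 FLOPs per device and the causally halved scoring term to roughly 5.5e11 FLOPs per device, so the indexer adds about one extra TFLOP per device per forward pass.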


def calculate_mla_tflops_per_device(config):
"""Calculate Multi-Head Latent Attention TFLOP"""
"""Calculate Multi-Head Latent Attention TFLOP (handles causal reduction)"""
batch_len = config.per_device_batch_size * config.max_target_length
qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
# calculate mla query projection

# 1. calculate mla query projection
if config.q_lora_rank == 0:
q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
else:
@@ -333,7 +408,8 @@ def calculate_mla_tflops_per_device(config):
* batch_len
* (config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum)
)
# calculate mla kv projection with down and up flops

# 2. calculate mla kv projection
kv_flops = (
2
* batch_len
@@ -344,9 +420,31 @@ def calculate_mla_tflops_per_device(config):
)
qkv_flops = q_flops + kv_flops

attention_flops = (
2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
)
# 3. calculate attention
if config.use_sparse_indexer and config.max_target_length > config.index_topk:
# get indexer flops
indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
qkv_flops += indexer_proj_flops

# scale attention by the fraction of the T x T causal matrix actually computed after Top-K selection
# this follows the area: (TK - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
multiplier = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
attention_flops = (
2
* batch_len
* config.max_target_length
* config.num_query_heads
* (qk_head_dim_sum + config.v_head_dim)
* multiplier
)
attention_flops += indexer_scoring_flops
else:
# standard MLA, or sparse indexer with max_target_length <= index_topk;
# in both cases every query attends to all previous tokens, so the plain causal mask applies
attention_flops = (
2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
)
attention_flops = attention_flops / 2
projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
return qkv_flops, attention_flops, projection_flops
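
As a rough sanity check under the test configuration below (T = 4096, K = 2048, so the multiplier is 0.375): the sparse branch scales the core attention term by 0.375 instead of the dense-causal factor of 0.5, i.e. about 75% of the standard causal attention FLOPs, and then adds the already causally-halved indexer scoring FLOPs on top. These numbers are illustrative and not computed by the diff itself.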

@@ -546,7 +644,7 @@ def calculate_tflops_training_per_device(config, log=True):

# Attention flops
if config.attention_type == "mla":
qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
else:
qkv_flops = (
2
@@ -568,11 +666,11 @@
* config.head_dim
)

# Divide attention flops by 2 due to causal mask
# References:
# NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
# NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
causal_attention_flops = noncausal_attention_flops / 2

# Embedding flops
embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
52 changes: 52 additions & 0 deletions tests/unit/flop_calculation_test.py
@@ -292,3 +292,55 @@ def test_gpt_oss_20b_flops(self):
)
calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)

@pytest.mark.cpu_only
def test_deepseek32_671b_flops(self):
"""Test DeepSeek3.2-671b FLops calculation"""
kwargs = {
# Model bases
"model_name": "deepseek3.2-671b",
"override_model_config": True,
# Core workload parameters
"per_device_batch_size": 4,
"max_target_length": 4096,
"num_experts": 256,
"num_experts_per_tok": 8,
"shared_experts": 1,
# Model dimensions
"base_emb_dim": 7168,
"base_num_query_heads": 128,
"base_num_kv_heads": 128,
"base_mlp_dim": 18432,
"base_moe_mlp_dim": 2048,
"base_num_decoder_layers": 61,
"first_num_dense_layers": 3,
"mlp_activations": ["silu", "linear"],
"vocab_size": 129280,
# MLA
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"skip_jax_distributed_system": True,
# Indexer for DeepSeek Sparse Attention
"use_sparse_indexer": True,
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 2048,
# TODO(ranran): remove after flash attention is supported
"attention": "dot_product",
}
B = kwargs["per_device_batch_size"]
S = kwargs["max_target_length"]
attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs)
# deepseek3-671b has ~37B active parameters
# https://arxiv.org/pdf/2412.19437
golden_param_size = 37e9
golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops
cfg = pyconfig.initialize(
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
**kwargs,
)
calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
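
For reference (an illustrative breakdown, not part of the test): the 6 * B * S * params term is the standard estimate of 2 FLOPs per parameter for the forward pass plus 4 for the backward pass, which for B = 4, S = 4096, and 37e9 active parameters comes to roughly 3.6e3 TFLOPs per device per step; the attention term is added separately because it scales with S^2 rather than with the parameter count.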