-
Notifications
You must be signed in to change notification settings - Fork 511
Expand file tree
/
Copy pathmaxtext_utils.py
More file actions
1334 lines (1109 loc) · 55.4 KB
/
maxtext_utils.py
File metadata and controls
1334 lines (1109 loc) · 55.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
""" Utils that are only interesting to MaxText. """
import functools
import pickle
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
import numpy as np
from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
from MaxText import checkpointing
from MaxText import max_logging
from MaxText import max_utils
from MaxText import multimodal_utils
from MaxText import sharding
from MaxText.configs import types
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
from MaxText.inference.page_manager import PageState
# Module-level marker string; its consumers are not visible in this file.
# NOTE(review): presumably used to tag parameters/collections whose value is
# overwritten by the gradient rather than updated by the optimizer — confirm.
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
def get_input_data_sharding(config, mesh):
  """Deprecated alias that forwards to `sharding.get_input_data_sharding`.

  Logs a deprecation warning, then delegates with the same arguments and
  returns the delegate's result unchanged.
  """
  deprecation_warning = (
      "WARNING: Function maxtext_utils.get_input_data_sharding is deprecated. Please use sharding.get_input_data_sharding."
  )
  max_logging.log(deprecation_warning)
  return sharding.get_input_data_sharding(config, mesh)
def assert_params_sufficiently_sharded(params, mesh, tolerance):
  """Deprecated alias for `sharding.assert_params_sufficiently_sharded`.

  Logs a deprecation warning, then delegates with the same arguments and
  returns whatever the `sharding` implementation returns.
  """
  # The two literals are joined by implicit concatenation; the trailing space
  # keeps the sentences from running together ("deprecated.Please ...").
  max_logging.log(
      "WARNING: Function maxtext_utils.assert_params_sufficiently_sharded is deprecated. "
      "Please use sharding.assert_params_sufficiently_sharded."
  )
  return sharding.assert_params_sufficiently_sharded(params, mesh, tolerance)
def add_data_to_sharding(mesh, path, aval, shardings):
  """Deprecated alias that forwards to `sharding.add_data_to_sharding`.

  Emits a deprecation warning via max_logging before delegating; all
  arguments pass through unchanged and the delegate's result is returned.
  """
  warning_text = (
      "WARNING: Function maxtext_utils.add_data_to_sharding is deprecated. Please use sharding.add_data_to_sharding."
  )
  max_logging.log(warning_text)
  return sharding.add_data_to_sharding(mesh, path, aval, shardings)
def maybe_update_params_sharding_with_opt(config, state_mesh_shardings):
  """Deprecated alias for `sharding.maybe_update_params_sharding_with_opt`.

  Logs a deprecation warning, then delegates with the same arguments and
  returns whatever the `sharding` implementation returns.
  """
  # The two literals are joined by implicit concatenation; the trailing space
  # keeps the sentences from running together ("deprecated.Please ...").
  max_logging.log(
      "WARNING: Function maxtext_utils.maybe_update_params_sharding_with_opt is deprecated. "
      "Please use sharding.maybe_update_params_sharding_with_opt."
  )
  return sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode):
  """Deprecated alias that forwards to `sharding.all_gather_over_fsdp`.

  Logs a deprecation warning and then delegates, passing every argument
  through and returning the delegate's result.
  """
  notice = "WARNING: Function maxtext_utils.all_gather_over_fsdp is deprecated. Please use sharding.all_gather_over_fsdp."
  max_logging.log(notice)
  return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode)
def get_functional_train_with_signature(
    train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=None
):
  """Bind the static arguments of `train_step` and describe its jit signature.

  Returns:
    A 5-tuple `(functional_train, in_shardings, out_shardings, static_argnums,
    donate_argnums)`:
      functional_train: `train_step` with (model, config, state_mesh_shardings,
        params_shardings) bound positionally, renamed to "train_step".
      in_shardings: shardings for the (state, batch, rng) inputs.
      out_shardings: shardings for the (state, metrics) outputs.
      static_argnums: empty, since the static arguments are already bound.
      donate_argnums: 0, the index of the state argument.
  """
  bound_step = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
  bound_step.__name__ = "train_step"

  input_shardings = (state_mesh_shardings, data_sharding, None)  # state, batch, rng
  output_shardings = (state_mesh_shardings, None)  # state, metrics
  # model/config were bound by the partial above, so nothing is static at jit time.
  no_static_args = ()
  # Donating the state (argument 0) lets the compiler reuse its buffers.
  donated_state_index = 0
  return bound_step, input_shardings, output_shardings, no_static_args, donated_state_index
def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shardings, model, config):
  """Bind the static arguments of `eval_step` and describe its jit signature.

  Returns:
    A 5-tuple `(functional_eval, in_shardings, out_shardings, static_argnums,
    donate_argnums)`. The step has (model, config) bound and is renamed to
    "eval_step". Inputs are (state, batch, rng). Output sharding is None
    (metrics). Nothing is static, and nothing is donated because the state is
    kept after evaluation.
  """
  bound_eval = functools.partial(eval_step, model, config)
  bound_eval.__name__ = "eval_step"
  return (
      bound_eval,
      (state_mesh_shardings, data_sharding, None),  # state, batch, rng
      None,  # metrics
      (),  # model and config are already bound
      (),  # keep the state: eval_step must not donate it
  )
def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode):
  """Reorder `batch` for causal load balancing and re-apply the batch's sharding.

  The sharding is copied from the first value of `batch` (in iteration order)
  that is a `jax.Array`. If no value is a `jax.Array`, the reordered result is
  returned without calling `sharding.maybe_shard_with_name`.
  """
  reordered = max_utils.reorder_causal_load_balanced(batch, cp_size)
  reference_array = next((value for value in batch.values() if isinstance(value, jax.Array)), None)
  if reference_array is None:
    return reordered
  return sharding.maybe_shard_with_name(reordered, reference_array.sharding, shard_mode)
def get_reorder_callable(cp_size, shard_mode):
  """Return a one-argument callable, usable with map(), that reorders and re-shards a batch.

  The callable is `shard_reorder_causal_load_balanced` with `cp_size` and
  `shard_mode` bound as keyword arguments.
  """
  reorder_batch = functools.partial(shard_reorder_causal_load_balanced, cp_size=cp_size, shard_mode=shard_mode)
  return reorder_batch
def get_shaped_batch(config):
  """Return abstract (shape/dtype only) stand-ins for one training batch.

  This is what eval_shape would return for the output of create_data_iterator,
  but eval_shape doesn't work there, see b/306901078.

  Args:
    config: Reads global_batch_size_to_load, max_target_length, use_multimodal
      and use_audio. model_name and micro_batch_size_to_train_on (and, for
      audio, the whole config) are read only when that modality is enabled.

  Returns:
    A dict of `jax.ShapeDtypeStruct`. The six token keys are int32 with shape
    (global_batch_size_to_load, max_target_length). "images" and "image_masks"
    are added when use_multimodal is set. "audios" (float32) is added when
    use_audio is set.
  """
  token_shape = (config.global_batch_size_to_load, config.max_target_length)
  token_keys = (
      "inputs",
      "inputs_position",
      "inputs_segmentation",
      "targets",
      "targets_position",
      "targets_segmentation",
  )
  shaped_batch = {key: jax.ShapeDtypeStruct(token_shape, jnp.int32) for key in token_keys}

  if config.use_multimodal:
    image_shape = multimodal_utils.get_dummy_image_shape_for_init(
        config.model_name, batch_size=config.micro_batch_size_to_train_on
    )
    shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32)
    # The mask keeps only the leading two dimensions of the image shape.
    shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32)

  if config.use_audio:
    audio_shape = multimodal_utils.get_dummy_audio_shape_for_init(
        config.model_name, config=config, batch_size=config.micro_batch_size_to_train_on
    )
    shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32)

  return shaped_batch
def should_prevent_cse_in_remat(config):
  """Decide whether remat should prevent common subexpression elimination (CSE).

  CSE is allowed (returns False) in two cases:
    * layers are scanned (`scan_layers` is truthy), or
    * gradient accumulation is on (`gradient_accumulation_steps > 1`) and the
      hardware is "gpu" or "gpu_multiprocess".
  In every other case CSE is prevented (returns True).

  Args:
    config: Object with scan_layers, gradient_accumulation_steps and hardware.
      gradient_accumulation_steps is only read when scan_layers is falsy, and
      hardware only when accumulation is on.

  Returns:
    bool: True if CSE should be prevented, False otherwise.
  """
  cse_allowed = config.scan_layers or (
      config.gradient_accumulation_steps > 1 and config.hardware in ("gpu", "gpu_multiprocess")
  )
  return not cse_allowed
def load_compiled(config, partial_train, state, execution_devices):
  """Load a serialized, ahead-of-time compiled train step.

  `partial_train` and `state` are needed only to rebuild the input/output
  pytree structures (in_tree / out_tree) that `deserialize_and_load` requires.
  Serializing those trees alongside the executable is still to be done.

  Args:
    config: Reads compiled_trainstep_file, plus whatever get_shaped_batch reads.
    partial_train: The train-step callable whose output structure is traced.
    state: Example train state used as the first positional input.
    execution_devices: Devices passed through to deserialize_and_load.

  Returns:
    The deserialized executable returned by `deserialize_and_load`.
  """
  # SECURITY: pickle.load can execute arbitrary code. Only point
  # config.compiled_trainstep_file at files produced by a trusted compile run.
  with open(config.compiled_trainstep_file, "rb") as compiled_file:
    serialized_compiled = pickle.load(compiled_file)

  example_args = (state, get_shaped_batch(config), jax.random.PRNGKey(0))
  example_kwargs = {}
  _, in_tree = jax.tree_util.tree_flatten((example_args, example_kwargs))
  output_shapes = jax.eval_shape(partial_train, *example_args, **example_kwargs)
  _, out_tree = jax.tree_util.tree_flatten(output_shapes)
  return deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices)
def calculate_tokens_training_per_device(config):
  """Return the number of training tokens processed per device per step.

  Computed as max_target_length * per_device_batch_size * gradient_accumulation_steps.
  """
  seq_len = config.max_target_length
  batch = config.per_device_batch_size
  accumulation = config.gradient_accumulation_steps
  return seq_len * batch * accumulation
def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
  """Return `(attention_tflops, learnable_weight_tflops)` for Gemma2 training per device.

  Gemma2 combines a [local_attention, global_attention] pair into one decoder
  layer, and the local half uses sliding-window attention. Both results
  include the x3 forward + backward factor and are expressed in TFLOPs.
  """
  batch = config.per_device_batch_size
  seq_len = config.max_target_length
  heads = config.num_query_heads
  head_dim = config.head_dim
  window = min(config.sliding_window_size, seq_len)

  global_attention_flops = 4 * batch * seq_len**2 * heads * head_dim
  local_attention_flops = 4 * batch * seq_len * window * heads * head_dim
  causal_attention_flops = (global_attention_flops + local_attention_flops) / 2
  attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

  # Each combined decoder layer holds two sets of weights (local + global),
  # hence the factor 2 on num_decoder_layers.
  per_layer_weight_flops = total_ffn_flops + qkv_flops + projection_flops
  learnable_weight_tflops = (per_layer_weight_flops * config.num_decoder_layers * 2 + embedding_flops) * 3 / 10**12
  return attention_tflops, learnable_weight_tflops
def calculate_mixed_attention_model_tflops_training_per_device(
    config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
):
  """Return `(attention_tflops, learnable_weight_tflops)` for mixed local/global attention models.

  Used for models such as Gemma3 and GPT-OSS. One in every
  `attention_pattern_length` layers uses global (full) attention; the rest use
  sliding-window (local) attention. Both results include the x3 forward +
  backward factor and are expressed in TFLOPs.
  """
  batch = config.per_device_batch_size
  seq_len = config.max_target_length
  heads = config.num_query_heads
  head_dim = config.head_dim

  total_layers = config.num_decoder_layers
  global_layers = total_layers // attention_pattern_length
  local_layers = total_layers - global_layers

  # Full attention per layer: 4 * B * T^2 * H * D.
  global_flops_per_layer = 4 * batch * seq_len**2 * heads * head_dim
  # Sliding window per layer: 4 * B * T * min(W, T) * H * D.
  local_flops_per_layer = 4 * batch * seq_len * min(config.sliding_window_size, seq_len) * heads * head_dim

  noncausal_flops = global_layers * global_flops_per_layer + local_layers * local_flops_per_layer
  attention_tflops = noncausal_flops / 2 * 3 / 10**12  # causal mask halves, x3 for fwd + bwd

  # FFN, QKV and output projections are present in every layer.
  weight_flops_per_layer = total_ffn_flops + qkv_flops + projection_flops
  learnable_weight_tflops = (weight_flops_per_layer * total_layers + embedding_flops) * 3 / 10**12
  return attention_tflops, learnable_weight_tflops
def _calculate_chunked_attention_flops_per_layer(config, seq_len, chunk_size):
"""Calculates the non-causal FLOPs for a single layer of chunked attention."""
num_chunks = seq_len // chunk_size
rem_chunk_size = seq_len % chunk_size
# The complexity of chunked attention is the sum of squares of chunk lengths.
chunked_complexity = (num_chunks * chunk_size**2) + (rem_chunk_size**2)
# The formula for non-causal attention FLOPs is 4 * B * complexity * H * D,
# where B=batch_size, H=num_heads, D=head_dim.
return 4 * config.per_device_batch_size * chunked_complexity * config.num_query_heads * config.head_dim
def calculate_llama4_attention_tflops(config):
  """Return attention-only training TFLOPs for Llama4.

  Llama4 alternates global and chunked attention layers. The number of global
  layers is num_decoder_layers // nope_layer_interval (NoPE layers use global
  attention); all remaining layers use chunked attention with window
  chunk_attn_window_size. The result applies the causal halving and the x3
  forward + backward factor.
  """
  total_layers = config.num_decoder_layers
  seq_len = config.max_target_length

  global_layers = total_layers // config.nope_layer_interval
  chunked_layers = total_layers - global_layers

  full_flops_per_layer = 4 * config.per_device_batch_size * seq_len**2 * config.num_query_heads * config.head_dim
  chunked_flops_per_layer = _calculate_chunked_attention_flops_per_layer(
      config, seq_len, config.chunk_attn_window_size
  )

  noncausal_flops = (global_layers * full_flops_per_layer) + (chunked_layers * chunked_flops_per_layer)
  # Halve for the causal mask, x3 for fwd + bwd, then convert to TFLOPs.
  return noncausal_flops / 2 * 3 / 10**12
def calculate_indexer_mask_ratio(index_topk, max_target_length):
  """Return the sparse-to-dense ratio used to scale Indexer attention FLOPs.

  Each query scores all previous tokens causally until it reaches the Top-K
  limit; after that it keeps exactly K keys. The actual selection looks like
  this (T=8, K=4):

    Key (S) ->
    Q1 [X . . . . . . .]  <- 1 token scored
    Q2 [X X . . . . . .]  <- 2 tokens scored
    Q3 [X X X . . . . .]  <- 3 tokens scored
    Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
    Q5 [X X X . X . . .]  <- 4 tokens scored
    Q6 [X X . X . X . .]  <- 4 tokens scored
    Q7 [X . X X . . X .]  <- 4 tokens scored
    Q8 [X X . X . . . X]  <- 4 tokens scored

  For MFU the selected keys are counted as if packed to the left:

    Q1 [X . . . . . . .]  ...  Q4..Q8 [X X X X . . . .]

  Area:
    - Triangle (queries 1..K):    K^2 / 2
    - Rectangle (queries K+1..T): (T - K) * K
    - Total active area:          TK - K^2 / 2
    - Dense area:                 T^2

  Ratio = (TK - 0.5*K^2) / T^2 = (K/T) - 0.5*(K/T)^2

  Args:
    index_topk: K, the number of keys kept per query.
    max_target_length: T, the sequence length.

  Returns:
    float: the sparse-to-dense ratio.
  """
  seq_len = float(max_target_length)
  top_k = float(index_topk)
  k_over_t = top_k / seq_len
  return k_over_t - (0.5 * k_over_t**2)
def calculate_indexer_tflops_per_device(config):
  """Return `(proj_flops, scoring_flops)` for the DeepSeek Lightning Indexer.

  Despite the function name, both values are raw FLOPs (not divided by 1e12).

  Returns:
    proj_flops: Query, key and head-weight projection FLOPs.
    scoring_flops: QK product plus head-reduction FLOPs, halved for the causal
      mask that restricts the search space before Top-K filtering.
  """
  seq_len = config.max_target_length
  batch_len = config.per_device_batch_size * seq_len
  n_heads = config.index_n_heads
  head_dim = config.index_head_dim

  # Projections.
  #   query:       [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
  #   key:         [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
  #   head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
  query_proj = 2 * batch_len * config.q_lora_rank * n_heads * head_dim
  key_proj = 2 * batch_len * config.emb_dim * head_dim
  head_weight_proj = 2 * batch_len * config.emb_dim * n_heads
  proj_flops = query_proj + key_proj + head_weight_proj

  # Index scores.
  #   QK product: [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
  #               -> [batch, seq, seq, index_n_heads]
  #   head sum:   [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
  qk_product = 2 * batch_len * seq_len * n_heads * head_dim
  head_reduction = 2 * batch_len * seq_len * n_heads
  scoring_flops = (qk_product + head_reduction) / 2  # causal (triangular) interactions only
  return proj_flops, scoring_flops
def calculate_mla_tflops_per_device(config):
  """Return Multi-Head Latent Attention FLOPs per device, with causal reduction.

  Returns:
    A tuple `(qkv_flops, attention_flops, projection_flops)` of raw FLOPs:
      qkv_flops: Query and KV projection FLOPs. When the sparse indexer is
        active, the indexer's projection FLOPs are included.
      attention_flops: Score and value matmul FLOPs, halved for the causal
        mask. The sparse indexer is active when use_sparse_indexer is set and
        max_target_length > index_topk. In that case the dense term is first
        scaled by calculate_indexer_mask_ratio and the indexer's scoring FLOPs
        are added, both before halving.
      projection_flops: Output projection FLOPs.
  """
  batch_len = config.per_device_batch_size * config.max_target_length
  qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim

  # 1. Query projection: direct when q_lora_rank == 0, else a low-rank down + up pair.
  if config.q_lora_rank == 0:
    q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
  else:
    q_down_up = config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum
    q_flops = 2 * batch_len * q_down_up

  # 2. KV projection: compress to (kv_lora_rank + rope dims), then expand per head.
  kv_per_token = config.emb_dim * (
      config.kv_lora_rank + config.qk_rope_head_dim
  ) + config.kv_lora_rank * config.num_query_heads * (config.qk_nope_head_dim + config.v_head_dim)
  kv_flops = 2 * batch_len * kv_per_token
  qkv_flops = q_flops + kv_flops

  # 3. Attention.
  dense_attention_flops = (
      2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
  )
  if config.use_sparse_indexer and config.max_target_length > config.index_topk:
    indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
    qkv_flops += indexer_proj_flops
    # Fraction of the T x T causal matrix the indexer explores: (TK - 0.5*K^2) / T^2.
    explored_fraction = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
    attention_flops = dense_attention_flops * explored_fraction + indexer_scoring_flops
  else:
    # Standard MLA, or an indexer whose Top-K already covers the whole sequence:
    # the plain causal mask is the efficient representation.
    attention_flops = dense_attention_flops
  attention_flops = attention_flops / 2

  projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
  return qkv_flops, attention_flops, projection_flops
def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
  """Return matmul FLOPs of one FFN block with hidden width `mlp_dim`.

  Applies to:
    - Dense FFN layers (mlp_dim = config.mlp_dim).
    - MoE FFN layers (mlp_dim = config.moe_mlp_dim); callers scale the result
      by shared_experts or num_experts_per_tok.

  The input side has one emb_dim x mlp_dim matmul per entry in
  config.mlp_activations. The output side has a single mlp_dim x emb_dim matmul.
  """
  input_projection_count = len(config.mlp_activations)
  up_flops = (
      2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim * input_projection_count
  )
  down_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim
  return up_flops + down_flops
def calculate_routed_and_shared_ffn_tflops_per_device(config):
  """Return total FFN FLOPs across all layers for DeepSeek-style dense + MoE decoders.

  Dense layers use config.mlp_dim. Each MoE layer pays for the router gate,
  `shared_experts` shared experts and `num_experts_per_tok` routed experts,
  all with width config.moe_mlp_dim. Layer counts come from get_dense_moe_layers.
  """
  router_gate_flops = (
      2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
  )
  num_dense_layers, num_moe_layers = get_dense_moe_layers(config)

  dense_total = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * num_dense_layers
  single_expert = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim)
  shared_total = single_expert * config.shared_experts
  routed_total = single_expert * config.num_experts_per_tok
  moe_total = (router_gate_flops + shared_total + routed_total) * num_moe_layers
  return dense_total + moe_total
def get_dense_moe_layers(config):
  """Return `(num_dense_layers, num_moe_layers)` for a mixed dense/MoE decoder.

  DeepSeek: the first `first_num_dense_layers` layers are dense and the rest
  are MoE. Llama4: num_decoder_layers // interleave_moe_layer_step layers are
  MoE and the rest are dense.

  Raises:
    ValueError: if config.decoder_block is neither DeepSeek nor Llama4.
  """
  block = config.decoder_block
  if block == DecoderBlockType.DEEPSEEK:
    dense_layers = config.first_num_dense_layers
    return dense_layers, config.num_decoder_layers - config.first_num_dense_layers
  if block == DecoderBlockType.LLAMA4:
    moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
    return config.num_decoder_layers - moe_layers, moe_layers
  raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")
def calculate_gemma3_vision_layers_tflops_per_device(config):
  """Estimate TFLOPs per device for the Gemma3 vision encoder (ViT-style).

  Architecture constants (patch size, widths, depth, pooling window) are fixed
  here to match Gemma3VisionEncoderLayer. Only the batch size, channel count,
  image size, text embedding width and the freeze flag come from `config`.

  Returns:
    A tuple `(total_tflops, learnable_weight_tflops, attention_tflops)`:
      total_tflops: learnable_weight_tflops + attention_tflops.
      learnable_weight_tflops: TFLOPs of the weight matmuls (patch embedding,
        QKV, attention output projection, MLP, vision embedder). Multiplied by
        3 when the encoder is trainable. When freeze_vision_encoder_params is
        set, only the vision embedder gets the extra 2x.
      attention_tflops: TFLOPs of the attention score / weighted-sum matmuls.
  """
  batch = config.per_device_batch_size
  channels = config.num_channels_for_vit
  image_side = config.image_size_for_vit  # square input (height == width); Gemma3 default 896
  text_emb_dim = config.emb_dim  # text embedding width after projection

  # Hard-coded in Gemma3VisionEncoderLayer.
  patch = 14
  width = 1152
  mlp_width = 4304
  depth = 27
  pool = 4

  # 1. Patch embedding (Conv2D).
  patches_h = image_side // patch
  patches_w = image_side // patch
  tokens = patches_h * patches_w  # e.g. 64*64 = 4096
  patch_embed = 2 * batch * tokens * (channels * patch * patch) * width

  # 2. Encoder: depth x Encoder1DBlock.
  qkv_per_layer = 3 * (2 * batch * tokens * width * width)
  attn_per_layer = 4 * batch * tokens * tokens * width
  out_proj_per_layer = 2 * batch * tokens * width * width  # projection after the attention matmuls
  mlp_per_layer = 2 * (2 * batch * tokens * width * mlp_width)  # two fc layers
  attention_flops = attn_per_layer * depth
  encoder = (qkv_per_layer + out_proj_per_layer + mlp_per_layer) * depth

  # 3. VisionEmbedder: one linear projection over the pooled tokens.
  pooled_tokens = (patches_h // pool) * (patches_w // pool)
  embedder = 2 * batch * pooled_tokens * width * text_emb_dim

  weight_flops = patch_embed + encoder + embedder
  if config.freeze_vision_encoder_params:
    # Frozen encoder: only the embedder (projector) is trained, so only its
    # FLOPs get the extra 2x on top of the forward pass.
    weight_flops = weight_flops + 2 * embedder
  else:
    weight_flops = weight_flops * 3  # fwd + bwd for every weight matmul

  # NOTE(review): attention FLOPs are returned without the x3 fwd + bwd factor,
  # unlike the decoder-side helpers in this file. Confirm this is intended.
  weight_tflops = weight_flops / 1e12
  attention_tflops = attention_flops / 1e12
  return weight_tflops + attention_tflops, weight_tflops, attention_tflops
def calculate_llama4_vision_layers_tflops_per_device(config):
  """Estimate TFLOPs per device for the Llama4 vision encoder (ViT-style).

  Counts the weight matmuls of the unfold-convolution patch embedding, the
  encoder layers (QKV, attention output projection, MLP), the pixel-shuffle
  MLP and the multimodal projector, plus the encoder's attention matmuls.

  Args:
    config: Reads per_device_batch_size, num_channels_for_vit,
      tile_size_for_vit, patch_size_for_vit, hidden_size_for_vit,
      intermediate_size_for_vit, num_hidden_layers_for_vit,
      projector_input_dim_for_vit, projector_output_dim_for_vit, base_emb_dim,
      pixel_shuffle_ratio_for_vit and freeze_vision_encoder_params.

  Returns:
    A tuple `(total_tflops, learnable_weight_tflops, attention_tflops)`:
      total_tflops: learnable_weight_tflops + attention_tflops.
      learnable_weight_tflops: TFLOPs from learnable weights (patch embedding,
        qkv, MLP, projections). Multiplied by 3 when the encoder is trainable.
        When freeze_vision_encoder_params is set, only the projector gets the
        extra 2x on top of the forward count.
      attention_tflops: TFLOPs from attention multiplications.
  """
  B = config.per_device_batch_size
  C = config.num_channels_for_vit
  H = W = config.tile_size_for_vit
  patch_size = config.patch_size_for_vit
  hidden_dim = config.hidden_size_for_vit
  intermediate_dim = config.intermediate_size_for_vit
  num_layers = config.num_hidden_layers_for_vit
  pixel_shuffle_fc1_out_dim = config.projector_input_dim_for_vit  # e.g. 4096
  pixel_shuffle_fc2_out_dim = config.projector_output_dim_for_vit  # e.g. 4096
  base_emb_dim = config.base_emb_dim
  pixel_shuffle_ratio = config.pixel_shuffle_ratio_for_vit  # e.g. 0.5

  num_patches = (H // patch_size) * (W // patch_size)  # e.g. 24*24 = 576
  # Pixel shuffle shrinks the token count by pixel_shuffle_ratio**2 (may be a float).
  pixel_shuffle_tokens = num_patches * pixel_shuffle_ratio**2  # e.g. 144

  # 1. Llama4UnfoldConvolution: patch extraction (lax.conv_general_dilated_patches)
  #    is reshaping/indexing only; the FLOPs are the linear projection
  #    C * patch_size * patch_size -> hidden_dim for each patch.
  patch_embed_flops = 2 * B * num_patches * (C * patch_size * patch_size) * hidden_dim

  # 2. Llama4VisionEncoder: num_layers * (qkv + attention output projection + mlp).
  seq_len = num_patches + 1  # +1 for the class token, e.g. 577
  qkv_flops_per_layer = 3 * (2 * B * seq_len * hidden_dim * hidden_dim)  # Q, K, V projections
  attn_flops_per_layer = 4 * B * seq_len * seq_len * hidden_dim  # attention scores and weighted sum
  projection_flops_per_layer = 2 * B * seq_len * hidden_dim * hidden_dim  # projection after attention
  mlp_flops_per_layer = 2 * (2 * B * seq_len * hidden_dim * intermediate_dim)  # two fc layers
  total_attn_flops = attn_flops_per_layer * num_layers
  vision_encoder_flops = (qkv_flops_per_layer + projection_flops_per_layer + mlp_flops_per_layer) * num_layers

  # 3. Llama4VisionPixelShuffleMLP:
  #    (B, tokens, intermediate_dim) -> (B, tokens, fc1_out) -> (B, tokens, fc2_out)
  pixel_shuffle_fc1_flops = 2 * B * pixel_shuffle_tokens * intermediate_dim * pixel_shuffle_fc1_out_dim
  pixel_shuffle_fc2_flops = 2 * B * pixel_shuffle_tokens * pixel_shuffle_fc1_out_dim * pixel_shuffle_fc2_out_dim
  pixel_shuffle_total_flops = pixel_shuffle_fc1_flops + pixel_shuffle_fc2_flops

  # 4. Llama4MultiModalProjector consumes the pixel-shuffle MLP output:
  #    (B, tokens, fc2_out) x (fc2_out, base_emb_dim). Its input width is the
  #    fc2 output width; fc1_out only matches it when the two MLP widths agree.
  projector_flops = 2 * B * pixel_shuffle_tokens * pixel_shuffle_fc2_out_dim * base_emb_dim

  # Learnable weights: all matmuls above.
  learnable_weight_flops = patch_embed_flops + vision_encoder_flops + pixel_shuffle_total_flops + projector_flops
  if config.freeze_vision_encoder_params:
    # Only the projector is trained: add 2x its FLOPs (its backward pass) to the forward count.
    learnable_weight_flops += 2 * projector_flops
  else:
    learnable_weight_flops *= 3  # fwd + bwd for every weight matmul

  # Convert to TFLOPs.
  learnable_weight_tflops = learnable_weight_flops / 1e12
  total_attn_tflops = total_attn_flops / 1e12
  total_tflops = learnable_weight_tflops + total_attn_tflops
  return total_tflops, learnable_weight_tflops, total_attn_tflops
def calculate_vision_encoder_tflops(config):
  """Return vision-encoder TFLOPs per prefill step per device.

  Dispatches on the model-name prefix: "gemma3" and "llama4" have dedicated
  estimators. Any other model logs a notice and counts as zero.

  Returns:
    A tuple `(total_tflops, learnable_weight_tflops, attention_tflops)`.
  """
  if config.model_name.startswith("gemma3"):
    estimator = calculate_gemma3_vision_layers_tflops_per_device
  elif config.model_name.startswith("llama4"):
    estimator = calculate_llama4_vision_layers_tflops_per_device
  else:
    max_logging.log(
        f"Vision encoder TFLOPs calculation not implemented for model {config.model_name}, counting as 0 for now."
    )
    return 0, 0, 0
  total_tflops, weight_tflops, attention_tflops = estimator(config)
  return total_tflops, weight_tflops, attention_tflops
def _tflops_percentage(part, whole):
  """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0 (keeps logging from dividing by zero)."""
  return 100 * part / whole if whole else 0.0


def calculate_tflops_training_per_device(config, log=True):
  """Calculate training TFLOPs per device for one train step.

  Counts forward-pass FLOPs of the learnable-weight matmuls (FFN / MoE, QKV and
  output projections, embedding) and of attention. Where computed in this
  function, these are multiplied by 3 to include the backward pass. Gemma2,
  Gemma3, GPT-OSS and Llama4 decoder blocks delegate part of the computation to
  model-specific helpers. The results are scaled by
  ``config.gradient_accumulation_steps``. DPO adds one reference-model forward
  pass, and multimodal configs add the vision-encoder TFLOPs.

  Args:
    config: model/training config.
    log: if True, print a breakdown of the computed TFLOPs.

  Returns:
    ``(total_tflops, learnable_weight_tflops, attention_tflops)``.
  """
  # MLP flops
  if config.num_experts > 1:
    # calculation based on dropless implementation
    if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
      total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
    else:
      gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
      total_ffn_flops = (
          gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok
      )
  else:
    total_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim)

  # Attention flops
  if config.attention_type == "mla":
    qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
  else:
    qkv_flops = (
        2
        * config.per_device_batch_size
        * config.max_target_length
        * config.emb_dim
        * (config.num_query_heads + 2 * config.num_kv_heads)
        * config.head_dim
    )
    noncausal_attention_flops = (
        4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
    )
    projection_flops = (
        2
        * config.per_device_batch_size
        * config.max_target_length
        * config.emb_dim
        * config.num_query_heads
        * config.head_dim
    )
    # Divide attention flops by 2 due to causal mask
    # References:
    # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
    # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
    causal_attention_flops = noncausal_attention_flops / 2

  # Embedding flops
  embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size

  # Combine flops with number of decoder layers
  if config.decoder_block == DecoderBlockType.GEMMA2:
    attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
    )
  elif config.decoder_block == DecoderBlockType.GEMMA3:
    attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
    )
  elif config.decoder_block == DecoderBlockType.GPT_OSS:
    attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2
    )
  elif config.decoder_block == DecoderBlockType.LLAMA4:
    # Use the new helper to calculate attention TFLOPs correctly.
    attention_tflops = calculate_llama4_attention_tflops(config)
    # The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure.
    # total_ffn_flops is not multiplied by num_decoder_layers on this path.
    learnable_weight_tflops = (
        (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
    )
  elif config.decoder_block == DecoderBlockType.DEEPSEEK:
    learnable_weight_tflops = (
        (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
    )
    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
  else:
    # multiply by 3 for both feed forward and back propagation flops
    learnable_weight_tflops = (
        ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
    )
    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

  learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
  attention_tflops = attention_tflops * config.gradient_accumulation_steps

  # DPO includes one additional forward pass per gradient accumulation step
  # NOTE(review): the reference model's weight TFLOPs go into the total but not into
  # learnable_weight_tflops, so the logged split does not sum to 100% under DPO.
  if config.use_dpo:
    reference_model_tflops = learnable_weight_tflops / 3  # additional forward pass
    reference_model_attention_tflops = attention_tflops / 3
    attention_tflops = attention_tflops + reference_model_attention_tflops
  else:
    reference_model_tflops = 0

  total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops

  if config.use_multimodal:
    # Add vision layers TFLOPs for multimodal models.
    # NOTE(review): these are added after the gradient-accumulation scaling above — confirm intended.
    mm_total_tflops, mm_learnable_weight_tflops, mm_attention_tflops = calculate_vision_encoder_tflops(config)
    if log:
      # Guarded percentages: mm_total_tflops is 0 for model families without a vision calculator.
      print(
          f"{config.model_name} vision layers per train step:\n",
          f"Total TFLOPs: {mm_total_tflops:.2f} \n",
          f"split as {_tflops_percentage(mm_learnable_weight_tflops, mm_total_tflops):.2f}% learnable weight flops",
          f"and {_tflops_percentage(mm_attention_tflops, mm_total_tflops):.2f}% attention flops;\n",
          f"learnable weight {mm_learnable_weight_tflops:.2f} TFLOPs, attention {mm_attention_tflops:.2f} TFLOPs",
      )
    total_tflops += mm_total_tflops
    learnable_weight_tflops += mm_learnable_weight_tflops
    attention_tflops += mm_attention_tflops

  if log:
    print(
        "Per train step:\n",
        f"Total TFLOPs: {total_tflops:.2f} \n",
        f"split as {_tflops_percentage(learnable_weight_tflops, total_tflops):.2f}% learnable weight flops",
        f"and {_tflops_percentage(attention_tflops, total_tflops):.2f}% attention flops",
    )
  return total_tflops, learnable_weight_tflops, attention_tflops
# https://arxiv.org/pdf/2204.02311.pdf Appendix B
def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True):
  """Calculate prefill TFLOPs per device for a prompt of ``prefill_length`` tokens.

  Follows the PaLM paper (Appendix B), forward pass only:
  ``2 * N * T`` FLOPs for the weight matmuls, plus the ``4 * L * H * Q * T**2``
  attention term, halved for the causal mask. Both are divided by the number of
  devices.

  Args:
    num_model_parameters: total number of model parameters (N).
    prefill_length: number of prompt tokens processed in the prefill step (T).
    config: config providing ``num_query_heads`` (H), ``num_decoder_layers`` (L)
      and ``head_dim`` (Q).
    log: if True, print a breakdown of the TFLOPs.

  Returns:
    ``(total_tflops, learnable_weight_tflops, causal_attention_tflops)``.
  """
  num_devices = jax.device_count()
  learnable_weight_tflops = 2 * num_model_parameters * prefill_length / num_devices / 1e12
  noncausal_attention_tflops = (
      4
      * config.num_query_heads
      * config.num_decoder_layers
      * config.head_dim
      * prefill_length**2
      / num_devices
      / 1e12
  )
  causal_attention_tflops = noncausal_attention_tflops / 2  # due to causality in attention
  total_tflops = learnable_weight_tflops + causal_attention_tflops
  if log:
    # Guard the percentages: total is 0 for an empty prefill.
    weight_pct = 100 * learnable_weight_tflops / total_tflops if total_tflops else 0.0
    attention_pct = 100 * causal_attention_tflops / total_tflops if total_tflops else 0.0
    print(
        "Per prefill step per device: \n",
        f"\tTotal TFLOPs: {total_tflops:.2f} \n",
        f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops:.2f} ",
        f"({weight_pct:.2f})% of Total\n",
        f"\t\tCausal attention TFLOPs: {causal_attention_tflops:.2f} ",
        f"({attention_pct:.2f})% of Total",
    )
  return total_tflops, learnable_weight_tflops, causal_attention_tflops
def apply_gradient_clipping(raw_grads, state, clipping_threshold):
  """Applies global-norm gradient clipping, with special handling for FLAX fp8 stats.

  Entries stored under ``OVERWRITE_WITH_GRADIENT`` (scales + amax history for
  delayed tensor scaling) are excluded from the clipping. They are copied into
  the result unchanged, and ``raw_grads`` is left holding them again on return.

  Args:
    raw_grads: A pytree of raw gradients.
    state: The optimizer state passed to the clipping transformation's ``update``.
    clipping_threshold: The gradient clipping threshold.

  Returns:
    A pytree of clipped gradients.
  """
  gradient_clip_transformation = optax.clip_by_global_norm(clipping_threshold)
  if OVERWRITE_WITH_GRADIENT not in raw_grads:
    grads, _ = gradient_clip_transformation.update(raw_grads, state, None)
    return grads

  # Scales + Amax History for Delayed Tensor Scaling SHOULD NOT be clipped or affect clipping.
  # They are removed temporarily; `finally` restores the caller's raw_grads even if update raises.
  fp8_stats = raw_grads.pop(OVERWRITE_WITH_GRADIENT)
  try:
    grads, _ = gradient_clip_transformation.update(raw_grads, state, None)
  finally:
    raw_grads[OVERWRITE_WITH_GRADIENT] = fp8_stats  # pytype: disable=unsupported-operands
  grads[OVERWRITE_WITH_GRADIENT] = fp8_stats  # pytype: disable=unsupported-operands
  return grads
def get_nested_value(dictionary, nested_key, default=None):
  """
  Retrieves a value from a nested key in a (possibly nested) mapping.

  Any ``collections.abc.Mapping`` is traversed, not only ``dict``, so read-only
  or frozen mappings are handled too.

  Args:
    dictionary: The mapping to search in.
    nested_key: A tuple representing the nested key, e.g., ('level1', 'level2', 'key').
    default: The value to return if the nested key is not found.

  Returns:
    The value associated with the nested key, or the default value if not found.
    An empty ``nested_key`` returns ``dictionary`` itself.
  """
  from collections.abc import Mapping  # pylint: disable=import-outside-toplevel

  current_level = dictionary
  for key in nested_key:
    # Stop as soon as we hit a non-mapping or a missing key.
    if not isinstance(current_level, Mapping) or key not in current_level:
      return default
    current_level = current_level[key]
  return current_level
def update_state_param(state, target_path, value):
  """
  Returns a copy of ``state`` where ``value`` is added to one parameter.

  The leaf of ``state.params`` whose key path equals ``target_path`` becomes
  ``param + value``. Every other leaf is passed through unchanged, so a path
  that matches nothing leaves the parameter values as they were.

  Args:
    state: The current TrainState.
    target_path: A tuple of dict keys matching the structure inside state.params.
    value: The amount to add to the targeted parameter.

  Returns:
    ``state.replace(params=...)`` with the updated parameter tree.
  """
  wanted_path = tuple(jax.tree_util.DictKey(key=k) for k in target_path)

  def _add_if_target(leaf_path, leaf):
    return leaf + value if leaf_path == wanted_path else leaf

  return state.replace(params=jax.tree_util.tree_map_with_path(_add_if_target, state.params))
def init_decode_state(apply_fn, params) -> train_state.TrainState:
  """Build a decode-only TrainState: step 0, no optimizer (``tx=None``) and an empty ``opt_state``."""
  return train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={})  # type: ignore
def init_training_state(apply_fn, params, tx):
  """Init train state for training, with an optimizer.

  Unlike ``init_decode_state``, this goes through ``train_state.TrainState.create``.
  Per the Flax docs, ``create`` initializes ``opt_state`` from ``tx`` and ``params``.

  Args:
    apply_fn: the model's apply function.
    params: the model variables.
    tx: the optax optimizer (GradientTransformation).

  Returns:
    The training TrainState.
  """
  state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
  return state
def init_initial_state(model, tx, config, is_training, key):
  """
  Initialize the model's variables and wrap them in a train or decode state.

  We pass in "static" objects like model, tx, config as JAX compares them by
  object hash, and instantiating them inside causes pjit top-level annotations
  to fail to match as pytree prefixes if we re-instantiate.

  Args:
    model: the Flax model whose ``init`` is called.
    tx: the optimizer; only used when ``is_training`` is True.
    config: config object (batch size, sequence length, model name, multimodal/audio flags).
    is_training: True to build a training state, False for a decode state.
    key: PRNG key, split into separate "params", "dropout" and "aqt" keys.

  Returns:
    The state built by ``init_training_state`` or ``init_decode_state``.
  """
  batch_size = config.micro_batch_size_to_train_on
  input_shape = (batch_size, config.max_target_length)
  image_shape = multimodal_utils.get_dummy_image_shape_for_init(config.model_name, batch_size=batch_size)
  audio_shape = multimodal_utils.get_dummy_audio_shape_for_init(
      config.model_name, config=config, batch_size=batch_size
  )

  # Split the master key into independent keys for each RNG collection
  # Reference: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html
  params_key, dropout_key, aqt_key = jax.random.split(key, 3)
  rngs = {"params": params_key, "dropout": dropout_key, "aqt": aqt_key}

  # Dummy encoder inputs are only built when the corresponding modality is enabled.
  encoder_images = np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None
  encoder_audios = np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None

  model_vars = model.init(
      rngs,
      np.ones(input_shape, dtype=jnp.int32),
      np.ones(input_shape, dtype=jnp.int32),
      encoder_images=encoder_images,
      encoder_audios=encoder_audios,
      # nnx_method="no_op",
  )
  if not is_training:
    return init_decode_state(model.apply, model_vars)
  return init_training_state(model.apply, model_vars, tx)
def get_abstract_param(model, config):
  """Return the abstract variable tree of ``model`` without materializing weights.

  ``model.init`` is traced with ``jax.eval_shape`` inside the model's mesh and
  the configured logical axis rules. The same fixed PRNG key is used for every
  RNG collection, since only structure/shapes are produced.

  Returns:
    The pytree returned by ``jax.eval_shape(model.init, ...)``.
  """
  with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    rng = jax.random.PRNGKey(0)
    batch_size = config.micro_batch_size_to_train_on
    input_shape = (batch_size, config.max_target_length)
    image_shape = multimodal_utils.get_dummy_image_shape_for_init(config.model_name, batch_size=batch_size)
    audio_shape = multimodal_utils.get_dummy_audio_shape_for_init(
        config.model_name, config=config, batch_size=batch_size
    )
    rngs = {"params": rng, "dropout": rng, "aqt": rng}
    return jax.eval_shape(
        model.init,
        rngs,
        jnp.ones(input_shape, dtype=jnp.int32),
        jnp.ones(input_shape, dtype=jnp.int32),
        encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
        encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None,
    )
def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
  """Build the decode state, loading params from ``config.load_parameters_path`` when set.

  Without a parameters path, random weights are generated through
  ``setup_initial_state`` in decode mode.

  Args:
    model: the flax model to initialize
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh
    checkpoint_manager: Checkpoint manager

  Returns:
    state: the unboxed decode state
    state_mesh_annotations: the mesh annotations for the state
  """
  if config.load_parameters_path:
    # Load params from checkpoint
    max_logging.log(f"Loading decode params from {config.load_parameters_path}")
    unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False)
    with nn_partitioning.axis_rules(config.logical_axis_rules):
      params = checkpointing.load_params_from_path(
          config.load_parameters_path,
          unboxed_abstract_state.params,
          config.checkpoint_storage_concurrent_gb,
          config.checkpoint_storage_use_ocdbt,
          config.checkpoint_storage_use_zarr3,
      )
    state = init_decode_state(None, params)
  else:
    # generate random params
    max_logging.log("No decode checkpoint specified - generating random weights.")
    state, state_mesh_annotations, _, _ = setup_initial_state(
        model,
        None,
        None,
        config,
        rng,
        mesh,
        checkpoint_manager,
        is_training=False,
    )
  return max_utils.unbox_logicallypartioned(state), state_mesh_annotations
def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
  """Set up the training state.

  Thin wrapper that calls ``setup_initial_state`` with ``is_training=True`` and
  returns its result unchanged.
  """
  return setup_initial_state(
      model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training=True
  )
def setup_initial_state(
model,
data_iterator,
tx,
config,
rng,
mesh,
checkpoint_manager,
is_training=True,
):
"""We initialize the model and optimizer state, and optionally load from a
checkpoint as necessary.
Args:
model: the flax model to initialize
tx: the optax.GradientTransformation
config: config object
rng: jax.prng key
mesh: jax.devices() mesh
checkpoint_manager: an Orbax checkpointing.CheckpointManager object
is_training: True to initialize training state, False for decode state
Returns:
state: the initialized train state
state_mesh_annotations: the mesh annotations for the train state
"""
unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(
model, tx, config, rng, mesh, is_training
)
# Initialization
with nn_partitioning.axis_rules(config.logical_axis_rules):
restored, raw_params = checkpointing.load_state_if_possible(
checkpoint_manager,
data_iterator,
config.load_parameters_path,