-
Notifications
You must be signed in to change notification settings - Fork 510
Expand file tree
/
Copy pathtypes.py
More file actions
2669 lines (2324 loc) · 124 KB
/
types.py
File metadata and controls
2669 lines (2324 loc) · 124 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023–2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic-based configuration system for MaxText, organized into modular classes."""
# pylint: disable=too-many-lines
import datetime
import enum
from enum import Enum
from jinja2 import Environment, TemplateSyntaxError
import logging
import math
from math import prod
import os
from tempfile import gettempdir
import yaml
from typing import Any, Literal, NewType, Optional
import jax
from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
from maxtext.utils import gcs_utils
from maxtext.utils import max_utils
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
from maxtext.utils import accelerator_to_spec_map
from pydantic.config import ConfigDict
from pydantic.fields import Field
from pydantic.functional_validators import field_validator, model_validator
from pydantic.main import BaseModel
from pydantic.types import NonNegativeFloat, NonNegativeInt, PositiveInt
class XProfTPUPowerTraceMode(enum.IntEnum):  # pylint: disable=invalid-name
  """Power-trace capture mode passed to XProf when profiling on TPU.

  IntEnum so the value can be forwarded directly as an integer option.
  """

  POWER_TRACE_NONE = 0  # power tracing disabled
  POWER_TRACE_NORMAL = 1
  POWER_TRACE_SPI = 2
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------
# Reusable Enums and Type Aliases
# ----------------------------------------------------------------------------
# Readability alias: values are plain strings holding filesystem or GCS paths.
PathStr = str
# Distinct nominal type for axis-name strings; presumably used in logical
# axis / sharding rule declarations elsewhere in this file — verify at use sites.
AxisNames = NewType("AxisNames", str)
class DType(str, Enum):
  """Supported data types for weights and activations.

  str mixin so values compare/serialize as plain strings in configs.
  """

  BFLOAT16 = "bfloat16"
  FLOAT32 = "float32"
  FLOAT16 = "float16"
class MatmulPrecision(str, Enum):
  """Precision levels for matrix multiplications."""

  DEFAULT = "default"
  HIGH = "high"
  HIGHEST = "highest"
  # same as default
  BFLOAT16 = "bfloat16"
  # same as highest
  FLOAT32 = "float32"
class QuantizationType(str, Enum):
  """Supported quantization schemes.

  The empty string means quantization is disabled. TE_* members select
  Transformer Engine kernel variants.
  """

  NONE = ""
  INT4 = "int4"
  INT8 = "int8"
  INTMP = "intmp"  # mixed-precision int; configured via quant_cfg_path
  FP8 = "fp8"
  NANOO_FP8 = "nanoo_fp8"
  # NOTE(review): value is "fp8_nanoo" while the sibling member is "nanoo_fp8";
  # looks intentional (two distinct schemes) but worth confirming upstream.
  FP8_NANO_V2 = "fp8_nanoo"
  FP8_GPU = "fp8_gpu"
  FP8_FULL = "fp8_full"
  TE_FP8_DS = "te_fp8_delayedscaling"
  TE_FP8_CS = "te_fp8_currentscaling"
  TE_MXFP8 = "te_mxfp8"
  TE_NVFP4 = "te_nvfp4"
  TE_NVFP4_NO_RHT = "te_nvfp4_no_rht"
class KvQuantAxis(str, Enum):
  """Axes to quantize over for the Key-Value cache.

  Empty string disables per-axis selection.
  """

  NONE = ""
  DKV = "dkv"
  HEADS_AND_DKV = "heads_and_dkv"
class RematPolicy(str, Enum):
  """Available rematerialization (gradient checkpointing) policies.

  *_OFFLOADED variants store saved activations on host memory rather than device.
  """

  FULL = "full"  # recompute everything in the backward pass
  MINIMAL = "minimal"
  SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP = "save_dot_with_context_except_mlp"
  SAVE_DOT_EXCEPT_MLPWI = "save_dot_except_mlpwi"
  SAVE_DOT_EXCEPT_MLP = "save_dot_except_mlp"
  SAVE_QKV_PROJ = "save_qkv_proj"
  QKV_PROJ_OFFLOADED = "qkv_proj_offloaded"
  CUSTOM = "custom"
  MINIMAL_OFFLOADED = "minimal_offloaded"
  SAVE_OUT_PROJ = "save_out_proj"
class RematLocation(str, Enum):
  """Specifies where to store activations for rematerialization."""

  REMAT = "remat"  # do not store; recompute on demand
  DEVICE = "device"  # keep on accelerator memory
  OFFLOAD = "offload"  # move to host memory
class OptimizerType(str, Enum):
  """Supported optimizer algorithms."""

  ADAMW = "adamw"
  ADAM_PAX = "adam_pax"  # Pax-style Adam variant
  SGD = "sgd"
  MUON = "muon"
class LearningRateScheduleType(str, Enum):
  """Supported learning rate schedule types."""

  COSINE = "cosine"
  WSD = "wsd"  # warmup-stable-decay
class WsdDecayStyle(str, Enum):
  """Supported decay styles for the final phase of a WSD schedule."""

  LINEAR = "linear"
  COSINE = "cosine"
class RopeType(str, Enum):
  """Supported Rotary Positional Embedding (RoPE) implementations."""

  DEFAULT = "default"
  LLAMA3_1 = "llama3.1"  # Llama 3.1 frequency-scaled RoPE
  YARN = "yarn"
class TokenizerType(str, Enum):
  """Supported tokenizer libraries."""

  SENTENCEPIECE = "sentencepiece"
  HUGGINGFACE = "huggingface"
  TIKTOKEN = "tiktoken"
class DatasetType(str, Enum):
  """Supported data loading pipelines."""

  SYNTHETIC = "synthetic"  # generated data, no I/O; useful for benchmarking
  HF = "hf"  # Hugging Face datasets
  GRAIN = "grain"
  TFDS = "tfds"
  C4MLPERF = "c4_mlperf"
class SamplingStrategy(str, Enum):
  """Supported decoding and sampling strategies."""

  GREEDY = "greedy"
  WEIGHTED = "weighted"
  NUCLEUS = "nucleus"  # top-p
  TOPK = "topk"
  COMPOSITE = "composite"
class ProfilerType(str, Enum):
  """Supported performance profilers; empty string disables profiling."""

  NONE = ""
  XPLANE = "xplane"
  NSYS = "nsys"
# ----------------------------------------------------------------------------
# Pydantic models for configuration
# ----------------------------------------------------------------------------
# Closed set of known model configuration names. Kept as a Literal (rather than
# an Enum) so pydantic validates plain YAML/CLI strings directly.
ModelName = Literal[
    "default",
    "llama2-7b",
    "llama2-13b",
    "llama2-70b",
    "llama3-8b",
    "llama3.1-8b-Instruct",
    "llama3-70b",
    "llama3.1-70b-Instruct",
    "llama3.1-8b",
    "llama3.1-70b",
    "llama3.1-405b",
    "llama3.3-70b",
    "mistral-7b",
    "mixtral-8x7b",
    "mixtral-8x22b",
    "deepseek2-16b",
    "deepseek2-236b",
    "deepseek3-671b",
    "deepseek3-671b-2dfsdp",
    "deepseek3-test",
    "deepseek3-tiny",
    "deepseek3.2-671b",
    "deepseek-custom",
    "kimi-k2-1t",
    "gemma-7b",
    "gemma-2b",
    "gemma2-2b",
    "gemma2-9b",
    "gemma2-27b",
    "gemma3-4b",
    "gemma3-12b",
    "gemma3-27b",
    "qwen2.5-7b",
    "qwen2.5-14b",
    "qwen3-0.6b",
    "qwen3-1.7b",
    "qwen3-1.7b-base",
    "qwen3-4b",
    "qwen3-4b-base",
    "qwen3-4b-thinking-2507",
    "qwen3-8b",
    "qwen3-8b-base",
    "qwen3-14b",
    "qwen3-14b-base",
    "qwen3-32b",
    "qwen3-235b-a22b",
    "qwen3-30b-a3b",
    "qwen3-30b-a3b-base",
    "qwen3-480b-a35b",
    "qwen3-next-80b-a3b",
    "qwen3-omni-30b-a3b",
    "gpt3-175b",
    "gpt3-22b",
    "gpt3-6b",
    "gpt3-52k",
    "gpt-oss-20b",
    "gpt-oss-120b",
    "llama4-17b-16e",
    "llama4-17b-128e",
    "olmo3-7b",
    "olmo3-7b-pt",
    "olmo3-32b",
]
class RunInfo(BaseModel):
  """Configuration for the overall run, model identity, and logging."""

  base_config: None | str = Field(
      None,
      description="Base config to inherit from. This is a meta-field and is consumed by the config loading system.",
  )
  run_name: str = Field(
      "",
      description="The name of the run. Checkpoints will be stored under this name.",
  )
  model_name: ModelName = Field("default", description="The name of the model configuration to use.")
  override_model_config: bool = Field(False, description="If True, allows overriding model parameters via CLI.")
  override_logical_axis_rules: bool = Field(
      False,
      description="If True, logical_axis_rules will be overridden instead of merged.",
  )
  log_config: bool = Field(
      True,
      description="If True, prints the final configuration after initialization.",
  )
  debug_sharding: bool = Field(False, description="If True, print model weight sharding details.")
  base_output_directory: PathStr = Field("", description="Base directory for all outputs, typically a GCS path.")
  sharding_strategy: None | Literal["experimental"] = Field(
      None,
      description="Experimental sharding strategy used for some inference configs.",
  )
class Checkpointing(BaseModel):
  """Core configuration for checkpointing and run restoration."""

  load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.")
  lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.")
  load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.")
  enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.")
  # NOTE(review): the description below reads like an RL actor/reference-model
  # flag, which does not obviously match the field name — confirm at use sites.
  load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.")
  async_checkpointing: bool = Field(True, description="If True, uses an asynchronous checkpointer for performance.")
  checkpoint_period: int = Field(10_000, description="The frequency (in steps) at which to save checkpoints.")
  max_num_checkpoints_to_keep: int | None = Field(None, description="Maximum number of checkpoints to keep.")
  enable_single_replica_ckpt_restoring: bool = Field(
      False, description="One replica reads and broadcasts the checkpoint."
  )
  force_unroll: bool = Field(
      False,
      description="During param-only checkpoint generation, whether to unroll the loop.",
  )
  checkpoint_is_quantized: bool = Field(
      False,
      description="Set to True if reading from a saved AQT quantized checkpoint.",
  )
  save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.")
  enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.")
  checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.")
  source_checkpoint_layout: Literal["orbax", "safetensors"] = Field(
      "orbax", description="The layout of the source checkpoint to load."
  )
  save_checkpoint_on_completion: bool = Field(
      True, description="If True, saves a final checkpoint upon training completion."
  )
  enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
  colocated_python_checkpointing: bool = Field(
      False,
      description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
  )
class OrbaxStorage(BaseModel):
  """Configuration for Orbax checkpoint storage options."""

  checkpoint_storage_target_data_file_size_bytes: int = Field(
      2147483648, description="Target file size for chunking large arrays in Orbax."  # 2 GiB
  )
  checkpoint_storage_use_ocdbt: bool = Field(True, description="Whether to use the OCDbT storage format for checkpoints.")
  checkpoint_storage_use_zarr3: bool = Field(
      True, description="Whether to use Zarr3 with OCDbT. Requires use_ocdbt=True."
  )
  checkpoint_storage_concurrent_gb: int = Field(96, description="Concurrent GB for I/O operations during checkpointing.")
class EmergencyCheckpointing(BaseModel):
  """Configuration for emergency (local) checkpointing.

  A period/interval of 0 disables the corresponding mechanism.
  """

  enable_multi_tier_checkpointing: bool = Field(
      False, description="Enables multi-tier checkpointing (local and persistent)."
  )
  local_checkpoint_directory: PathStr = Field("", description="Local directory for emergency checkpoints.")
  local_checkpoint_period: NonNegativeInt = Field(0, description="Frequency (in steps) for local emergency checkpoints.")
  multi_tier_checkpointing_backup_interval_minutes: NonNegativeInt = Field(
      0,
      description="Interval in minutes to back up local checkpoints to persistent storage.",
  )
  mtc_data_parallelism: int = Field(
      0,
      description="Number of identical pipelines in the job for multi-tier checkpointing. 0 defaults to num_slices.",
  )
  enable_emergency_checkpoint: bool = Field(
      False,
      description="Legacy flag for enabling emergency checkpointing. Prefer `enable_multi_tier_checkpointing`.",
  )
  use_replicator_service: bool = Field(
      False,
      description="Whether to use emergency checkpointing with the replicator service.",
  )
  replicator_backup_interval_minutes: NonNegativeInt = Field(
      0, description="Interval in minutes to back up local checkpoints."
  )
class DataTypes(BaseModel):
  """Configuration for data types and precision."""

  dtype: DType = Field(DType.BFLOAT16, description="The data type for activations.")
  grad_dtype: DType = Field(DType.FLOAT32, description="The data type for gradients.")
  weight_dtype: DType = Field(DType.FLOAT32, description="The data type for model weights.")
  matmul_precision: MatmulPrecision = Field(
      MatmulPrecision.DEFAULT,
      description="Precision level for matrix multiplications.",
  )
  activations_in_float32: bool = Field(
      False,
      description="If True, sets activations to float32 before the nonlinearity.",
  )
  # NOTE(review): plain str rather than DType — presumably intentional to allow
  # dtypes outside the DType enum for the vision encoder; confirm before tightening.
  dtype_mm: str = Field("float32", description="Data type for multimodal model's vision encoder")
class Quantization(BaseModel):
  """Configuration for model quantization."""

  quantization: None | QuantizationType = Field(
      QuantizationType.NONE,
      description="Activates quantization for transformer layers.",
  )
  replicate_quant_scale: bool = Field(
      False,
      description="Replicates quantization scale to avoid inefficient XLA fusion.",
  )
  quant_cfg_path: PathStr = Field("", description="Path to the configuration file for 'intmp' quantization.")
  quantize_kvcache: bool = Field(False, description="If True, quantizes the Key-Value cache.")
  kv_quant_axis: KvQuantAxis = Field(KvQuantAxis.HEADS_AND_DKV, description="Axes to quantize over for the KV cache.")
  kv_quant_dtype: Literal["int8", "int4"] = Field("int8", description="Data type for KV cache quantization.")
  quantization_local_shard_count: int = Field(-1, description="Shards the range finding operation for quantization.")
  use_qwix_quantization: bool = Field(False, description="Whether to use qwix for quantization.")
  weight_quantization_calibration_method: str = Field(
      "absmax",
      description="Quantization calibration method used for weights.",
  )
  act_quantization_calibration_method: str = Field(
      "absmax",
      description="Quantization calibration method used for activations.",
  )
  bwd_quantization_calibration_method: str = Field(
      "absmax",
      description="Quantization calibration method used for gradients.",
  )
class ModelArchitecture(BaseModel):
  """Core model architecture parameters.

  `base_*` dimensions are presumably multiplied by `global_parameter_scale`
  elsewhere to derive actual model sizes — verify against the scaling logic.
  """

  decoder_block: DecoderBlockType = Field(
      "llama2",
      description="The style of DecoderBlock to use (e.g., 'llama2', 'gemma').",
  )
  global_parameter_scale: int = Field(1, description="A global scaling factor for model dimensions.")
  base_emb_dim: int = Field(2048, description="Base embedding dimension.")
  base_num_query_heads: int = Field(16, description="Base number of query heads.")
  base_num_kv_heads: int = Field(16, description="Base number of key/value heads.")
  base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
  base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
  head_dim: int = Field(128, description="Dimension of each attention head.")
  mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
  mlp_activations_limit: float = Field(
      -1.0,
      description="Upper bound to clip the MLP activation values. -1.0 means no clipping.",
  )
  normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.")
  fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.")
  attention_bias: bool = Field(
      False,
      description="If True, adds a learnable bias to the query, key, and value projections.",
  )
  fused_mlp: bool = Field(False, description="If supported, fuse the MLP layers.")
class MTP(BaseModel):
  """Multi-Token Prediction Configs. mtp_num_layers == 0 disables MTP."""

  mtp_num_layers: NonNegativeInt = Field(0, description="The number of auxiliary prediction layers to use for MTP.")
  mtp_loss_scaling_factor: NonNegativeFloat = Field(
      0.1,
      description="The scaling factor (lambda) for the MTP auxiliary loss.",
  )
  mtp_eval_target_module: NonNegativeInt = Field(
      0,
      description="Specifies which MTP layer is used to calculate metrics.",
  )
class Logits(BaseModel):
  """Configuration for the final logits computation."""

  logits_via_embedding: bool = Field(False, description="If True, tie the embedding and unembedding matrices.")
  normalize_embedding_logits: bool = Field(
      True,
      description="If logits_via_embedding is true, normalize pre-softmax logits.",
  )
  logits_dot_in_fp32: bool = Field(False, description="Use fp32 for the logits dot product for stability.")
  cast_logits_to_fp32: bool = Field(True, description="Whether to cast the final logits to fp32.")
  final_logits_soft_cap: None | NonNegativeFloat = Field(
      None,
      description="Soft-cap value for the final logits. None or 0.0 means no cap.",
  )
  z_loss_multiplier: float = Field(0.0, description="The multiplier for the z-loss (e.g., 1e-4). 0.0 to disable.")
class Attention(BaseModel):
  """General configuration for the attention mechanism."""

  # Free-form string (not an enum): "autoselected" lets the framework pick the kernel.
  attention: str = Field(
      "autoselected",
      description="The attention algorithm to use (dot_product, flash, etc).",
  )
  share_kv_projections: bool = Field(False, description="If True, Key and Value use the same projection.")
  attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field(
      "global", description="The variant of attention to use."
  )
  attention_sink: bool = Field(False, description="If True, enables attention sinks.")
  float32_qk_product: bool = Field(False, description="In dot-product attention, cast query-key product to fp32.")
  float32_logits: bool = Field(
      False,
      description="In dot-product attention, cast logits to fp32 before softmax.",
  )
  sliding_window_size: NonNegativeInt = Field(0, description="The size of the sliding window for local attention.")
  chunk_attn_window_size: NonNegativeInt = Field(0, description="The window size for chunked attention.")
  attn_logits_soft_cap: None | NonNegativeFloat = Field(
      None, description="Soft-cap value for attention logits. None means no cap."
  )
  use_post_attn_norm: bool = Field(False, description="Apply LayerNorm after the attention block.")
  use_post_ffw_norm: bool = Field(False, description="Apply LayerNorm after the feed-forward block.")
  use_ragged_attention: bool = Field(False, description="Whether to use ragged attention kernels.")
  use_tokamax_gmm: bool = Field(
      False,
      description="Whether to use the Tokamax library for GMM kernel implementation.",
  )
  ragged_block_size: int = Field(256, description="Block size for ragged attention.")
  enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
  use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
  use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
  force_q_layout: bool = Field(False, description="Force the Q layout")
  use_qk_clip: bool = Field(False, description="Whether to use QK-Clip (MuonClip) for training stability.")
  qk_clip_threshold: float = Field(100.0, description="Threshold for QK-Clip (tau).")
class MoBa(BaseModel):
  """Configuration for Mixture of Block Attention (MoBA)."""

  moba: bool = Field(False, description="If True, enables Mixture of Block Attention.")
  moba_chunk_size: int = Field(1024, description="The chunk size for MoBA.")
  moba_topk: int = Field(8, description="The number of top-k chunks to select in MoBA.")
class MlaAttention(BaseModel):
  """Configuration for Multi-Layer Attention (MLA)."""

  mla_naive_kvcache: bool = Field(True, description="Whether to use naive kvcache for MLA attention.")
  q_lora_rank: NonNegativeInt = Field(0, description="Query LoRA rank for MLA.")  # 0 presumably disables Q compression
  kv_lora_rank: NonNegativeInt = Field(512, description="Key/Value LoRA rank for MLA.")
  qk_nope_head_dim: NonNegativeInt = Field(128, description="Dimension for non-RoPE part of QK heads in MLA.")
  qk_rope_head_dim: NonNegativeInt = Field(64, description="Dimension for RoPE part of QK heads in MLA.")
  v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")
class AttentionIndexer(BaseModel):
  """Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""

  use_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
  indexer_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
  indexer_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
  indexer_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
  indexer_sparse_training: bool = Field(
      False,
      description="Determines the training strategy for the indexer: Dense Warm-up or Sparse Training stage.",
  )
  indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")
class Llama4Attention(BaseModel):
  """Configuration specific to Llama4-style models."""

  use_qk_norm: bool = Field(
      False,
      description="Whether to apply L2 normalization to Query/Key vectors after RoPE.",
  )
  temperature_tuning: bool = Field(
      False,
      description="Dynamically scale attention temperature based on sequence length.",
  )
class SplashAttention(BaseModel):
  """Tunable block sizes for Splash Attention kernels."""

  sa_block_q: int = Field(512, description="Block size for Q in splash attention.")
  sa_block_kv: int = Field(512, description="Block size for KV in splash attention.")
  sa_block_kv_compute: int = Field(512, description="Block size for KV compute in splash attention.")
  sa_block_q_dkv: int = Field(512, description="Block size for Q_dkv in splash attention.")
  sa_block_kv_dkv: int = Field(512, description="Block size for KV_dkv in splash attention.")
  sa_block_kv_dkv_compute: int = Field(512, description="Block size for KV_dkv compute in splash attention.")
  sa_block_q_dq: int = Field(512, description="Block size for Q_dq in splash attention.")
  sa_block_kv_dq: int = Field(512, description="Block size for KV_dq in splash attention.")
  sa_use_fused_bwd_kernel: bool = Field(False, description="Use fused backward kernel in splash attention.")
  sa_q_layout: str = Field("HEAD_DIM_MINOR", description="Layout for Q in splash attention.")
  sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.")
  sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.")
  use_max_logit_estimate: int = Field(
      -1,
      description="-1 means no estimate, any > 0 value will be used as max logit estimate",
  )
  cost_estimate_flops_fwd: int = Field(
      -1,
      description="-1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash "
      "to overlap for communication (forward)",
  )
  cost_estimate_flops_bwd: int = Field(
      -1,
      description="-1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash "
      "to overlap for communication (backward)",
  )
  dq_reduction_steps: int = Field(
      0,
      description="the number of reduction steps. For now, only 3 or all the kv steps are supported.",
  )
  use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")
class PagedAttention(BaseModel):
  """Tunable parameters for Paged Attention kernels."""

  pagedattn_num_pages: int = Field(64, description="Total number of pages to allocate for paged attention.")
  pagedattn_tokens_per_page: int = Field(32, description="Number of tokens each page can hold.")
  pagedattn_pages_per_compute_block: int = Field(4, description="Number of pages processed together in pallas kernels.")
  pagedattn_max_pages_per_group: int = Field(-1, description="Max pages per request; -1 defaults to max_target_length.")
  # Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
  # TPUs, the head_dim is padded to the nearest multiple of 128.
  pagedattn_head_dim_alignment: int = Field(128, description="Alignment of head_dim to the nearest multiple.")
class MoEGeneral(BaseModel):
  """General configuration for Mixture of Experts (MoE) layers."""

  num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
  num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
  capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
  load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
  use_custom_sort_vjp: bool = Field(
      True,
      description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul.",
  )
  use_ring_of_experts: bool = Field(
      False,
      description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
  )
  use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
  interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
  expert_shard_attention_option: Literal["fsdp", "context"] = Field(
      "fsdp",
      description="How the expert axis is used to shard attention weights and activations.",
  )
  moe_fsdp_use_two_stage_all_gather: bool = Field(
      False,
      description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
  )
  shard_exp_on_fsdp: bool = Field(
      False,
      description="Shard the expert dimension of the MLP weights on the FSDP axis, "
      "and recommended only when num_experts is a multiple of fsdp_parallelism",
  )
  use_2d_fsdp_sharding: bool = Field(
      False,
      description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.",
  )
  norm_topk_prob: bool = Field(
      False,
      description="Enable top-k probability normalization for router weights (Qwen3-specific).",
  )
  float32_weight_sum: bool = Field(
      True,
      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
  )
class MoEKernels(BaseModel):
  """Configuration for MoE-specific kernels like Megablox.

  The wi_*/wo_* tile fields tune the grouped matmul (GMM) tiling for the input
  (wi) and output (wo) MLP projections, for the forward pass and the two
  backward-pass gradients (dlhs / drhs) respectively.
  """

  megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.")
  sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.")
  wi_tile_fwd_batch_seq: int = Field(
      512,
      description="forward pass tiling dimension for batch/sequence in GMM for wi.",
  )
  wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.")
  wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.")
  wi_tile_dlhs_batch_seq: int = Field(
      512,
      description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wi.",
  )
  wi_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wi.")
  wi_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wi.")
  wi_tile_drhs_batch_seq: int = Field(
      512,
      description="bwd pass drhs tiling dimension for batch/sequence in GMM for wi.",
  )
  wi_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wi.")
  wi_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wi.")
  wo_tile_fwd_batch_seq: int = Field(
      512,
      description="forward pass tiling dimension for batch/sequence in GMM for wo.",
  )
  wo_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wo.")
  wo_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wo.")
  wo_tile_dlhs_batch_seq: int = Field(
      512,
      description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wo.",
  )
  wo_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wo.")
  wo_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wo.")
  wo_tile_drhs_batch_seq: int = Field(
      512,
      description="bwd pass drhs tiling dimension for batch/sequence in GMM for wo.",
  )
  wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
  wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
  wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
  wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
  wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
  wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
  wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
  wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")
  wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
  wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")
  merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")
class DeepSeekMoE(BaseModel):
  """Configuration specific to DeepSeek-style MoE layers."""

  base_moe_mlp_dim: int = Field(default=7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
  first_num_dense_layers: NonNegativeInt = Field(default=0, description="Number of initial dense layers in the model.")
  shared_experts: PositiveInt = Field(default=1, description="Number of shared experts.")
  # Router scoring and bias knobs.
  routed_scaling_factor: float = Field(default=1.0, description="Scaling factor for routing scores.")
  routed_score_func: str = Field(default="", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
  routed_bias: bool = Field(default=False, description="Whether to add a bias term for routing.")
  routed_bias_update_rate: float = Field(default=0.0, description="Update rate applied to the router bias term.")
  mlp_bias: bool = Field(
      default=False,
      description=(
          "Whether to add a learnable bias for MLP matmul, "
          "and originally implemented to support the GPT-OSS model architecture"
      ),
  )
  # Group-limited routing; -1 leaves the feature disabled.
  n_routing_groups: int = Field(default=-1, description="Number of groups for routing, disabled by default.")
  topk_routing_group: int = Field(default=-1, description="Number of top groups to route inputs to.")
  # Micro-batch scheduling for overlapping communication with compute.
  use_batch_split_schedule: bool = Field(
      default=False,
      description="Whether to split batch into micro-batches to hide communications that yields performance benefits.",
  )
  batch_split_factor: int = Field(
      default=1,
      description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.",
  )
class Qwen3Next(BaseModel):
  """Configuration specific to Qwen3-Next models with Gated Delta Net."""

  # Gated Delta Net (GDN) geometry.
  gdn_conv_kernel_dim: int = Field(default=4, description="Kernel size for the 1D convolution in the Gated Delta Net.")
  gdn_key_head_dim: int = Field(default=128, description="Head dimension for the key/query in the Gated Delta Net.")
  gdn_value_head_dim: int = Field(default=128, description="Head dimension for the value in the Gated Delta Net.")
  gdn_num_key_heads: int = Field(default=16, description="Number of key/query heads in the Gated Delta Net.")
  gdn_num_value_heads: int = Field(default=32, description="Number of value heads in the Gated Delta Net.")
  gdn_chunk_size: int = Field(
      default=64, description="Chunk size for the parallel scan algorithm in the Gated Delta Net."
  )
  use_qk_norm_in_gdn: bool = Field(
      default=True,
      description="Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.",
  )
  partial_rotary_factor: float = Field(default=1.0, description="The ratio of dimension to apply ROPE on")
class HardwareAndMesh(BaseModel):
  """Configuration for hardware and parallelism mesh."""

  hardware: Literal["tpu", "gpu", "gpu_multiprocess", "cpu"] = Field(
      default="tpu", description="The type of hardware to run on."
  )
  num_slices: int = Field(default=-1, description="Number of TPU slices. Automatically determined.")
  # Names of every logical mesh axis; order is significant.
  mesh_axes: list[str] = Field(
      default=[
          "data",
          "stage",
          "fsdp",
          "fsdp_transpose",
          "sequence",
          "context",
          "context_autoregressive",
          "tensor",
          "tensor_transpose",
          "tensor_sequence",
          "expert",
          "autoregressive",
      ],
      description="The names of the axes in the logical device mesh.",
  )
  shard_mode: ShardMode = Field(default="auto", description="can be either auto or explicit")
  inhomogeneous_layer_cycle_interval: int = Field(
      default=1, description="The interval of repeated inhomogeneous layer patterns."
  )
  scan_layers: bool = Field(default=True, description="Whether to use jax.lax.scan over layers.")
  param_scan_axis: int = Field(default=1, description="Axis to scan over for parameters.")
  context_parallel_load_balance: bool = Field(
      default=True, description="Whether to use load balancing for context parallelism."
  )
  context_parallel_strategy: str = Field(
      default="all_gather", description="Strategy for context parallelism ('all_gather' or 'ring')."
  )
  custom_mesh: str = Field(default="", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
  custom_mesh_and_rule: str = Field(default="", description="Customized mesh and logical rules for granularity.")
  allow_split_physical_axes: bool = Field(
      default=False, description="Allow splitting physical axes for device mesh creation."
  )
  enable_nnx: bool = Field(default=False, description="Whether to use NNX for model definition.")
  optimize_mesh_for_tpu_v6e: bool = Field(default=False, description="Apply transformations to the mesh for TPU v6e.")
  shardy: bool = Field(default=True, description="Whether to use shardy XLA backend.")
  pure_nnx_decoder: bool = Field(default=False, description="Whether to enable pure NNX decoder.")
class LayoutAndSharding(BaseModel):
  """Configuration for data and model sharding rules."""

  logical_axis_rules: Any = Field(default=[], description="Rules for mapping logical axes to physical mesh axes.")
  data_sharding: Any = Field(default=[], description="Sharding for input data.")
  input_data_sharding_logical_axes: list[str] = Field(
      default=["activation_embed_and_logits_batch", "activation_norm_length"],
      description="Logical axes for sharding input data.",
  )
  # Bounded to [0, 1] since it is a fraction of the parameter count.
  sharding_tolerance: float = Field(
      default=0.02, ge=0.0, le=1.0, description="Allowed percentage of non-sharded parameters."
  )
  shard_optimizer_over_data: bool = Field(
      default=False, description="Enable ZeRO-1 optimizer sharding over the data axis."
  )
  internal_compile: bool = Field(
      default=False, description="Use internal_compile to bypass open-source topology mappings."
  )
  internal_compile_num_devices: int = Field(default=-1, description="Number of devices when using internal_compile.")
class DcnParallelism(BaseModel):
  """Parallelism dimensions across the DCN (Data Center Network)."""

  dcn_diloco_parallelism: int = Field(default=1, description="DCN axis for Diloco parallelism.")
  dcn_data_parallelism: int = Field(default=-1, description="DCN axis for data parallelism.")
  dcn_fsdp_parallelism: int = Field(default=1, description="DCN axis for FSDP.")
  dcn_fsdp_transpose_parallelism: int = Field(default=1, description="DCN axis for FSDP transpose.")
  dcn_sequence_parallelism: int = Field(default=1, description="DCN axis for sequence parallelism (not recommended).")
  dcn_context_parallelism: int = Field(default=1, description="DCN axis for context parallelism.")
  dcn_context_autoregressive_parallelism: int = Field(
      default=1, description="DCN axis for context autoregressive parallelism."
  )
  dcn_tensor_parallelism: int = Field(default=1, description="DCN axis for tensor parallelism (not recommended).")
  dcn_tensor_transpose_parallelism: int = Field(default=1, description="DCN axis for tensor transpose parallelism.")
  dcn_tensor_sequence_parallelism: int = Field(
      default=1, description="DCN axis for tensor sequence parallelism (not recommended)."
  )
  dcn_pipeline_parallelism: int = Field(default=1, description="DCN axis for pipeline parallelism.")
  dcn_expert_parallelism: int = Field(default=1, description="DCN axis for expert parallelism.")
  dcn_autoregressive_parallelism: int = Field(
      default=1, description="DCN axis for autoregressive parallelism (not recommended)."
  )
class IciParallelism(BaseModel):
  """Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""

  ici_diloco_parallelism: int = Field(default=1, description="ICI axis for Diloco parallelism.")
  ici_data_parallelism: int = Field(default=1, description="ICI axis for data parallelism.")
  ici_fsdp_parallelism: int = Field(default=-1, description="ICI axis for FSDP.")
  ici_fsdp_transpose_parallelism: int = Field(default=1, description="ICI axis for FSDP transpose.")
  ici_sequence_parallelism: int = Field(default=1, description="ICI axis for sequence parallelism.")
  ici_context_parallelism: int = Field(default=1, description="ICI axis for context parallelism.")
  ici_context_autoregressive_parallelism: int = Field(
      default=1, description="ICI axis for context autoregressive parallelism."
  )
  ici_tensor_parallelism: int = Field(default=1, description="ICI axis for tensor parallelism.")
  ici_tensor_transpose_parallelism: int = Field(default=1, description="ICI axis for tensor transpose parallelism.")
  ici_tensor_sequence_parallelism: int = Field(default=1, description="ICI axis for tensor sequence parallelism.")
  ici_autoregressive_parallelism: int = Field(default=1, description="ICI axis for autoregressive parallelism.")
  ici_pipeline_parallelism: int = Field(default=1, description="ICI axis for pipeline parallelism.")
  ici_expert_parallelism: int = Field(default=1, description="ICI axis for expert parallelism.")
class PipelineParallelism(BaseModel):
  """Configuration for pipeline parallelism."""

  pipeline_fsdp_ag_per_repeat: bool = Field(
      default=False, description="Enable weight prefetching for circular pipeline parallelism."
  )
  num_layers_per_pipeline_stage: int = Field(default=1, description="Number of layers to place on each pipeline stage.")
  # -1 means "derive from other settings" for the three fields below.
  num_pipeline_repeats: int = Field(
      default=-1, description="Number of pipeline repeats. Calculated from other params if -1."
  )
  pipeline_parallel_layers: int = Field(
      default=-1, description="Number of layers to pipeline. -1 pipelines all decoder layers."
  )
  num_pipeline_microbatches: int = Field(
      default=-1, description="Number of microbatches for the pipeline. -1 defaults to num_stages."
  )
  pipeline_delay_activation_forwarding: bool = Field(
      default=False, description="Delays activation forwarding to aid XLA optimization."
  )
  pipeline_fsdp_ag_once: bool = Field(
      default=False, description="If True, all-gather FSDP weights once per pipeline repeat."
  )
  # Scan/remat toggles for the pipeline loops.
  scan_pipeline_iterations: bool = Field(default=True, description="Use jax.lax.scan over pipeline iterations.")
  scan_pipeline_repeats: bool = Field(default=True, description="Use jax.lax.scan over pipeline repeats.")
  scan_layers_per_stage: bool = Field(default=False, description="Use jax.lax.scan over layers within a stage.")
  set_remat_policy_on_pipeline_iterations: bool = Field(
      default=True, description="Set remat policy on the pipeline scan."
  )
  set_remat_policy_on_layers_per_stage: bool = Field(
      default=False, description="Set remat policy on the inner layer scan."
  )
class RematAndOffload(BaseModel):
  """Configuration for gradient checkpointing (rematerialization) and offloading."""

  remat_policy: str = Field(
      RematPolicy.FULL.value,
      description="The rematerialization policy, trading off speed and memory.",
  )
  remat_policy_for_vit: str = Field("minimal", description="Remat policy for multimodal model's vision encoder.")
  # Per-tensor remat locations; each selects where the named activation lives.
  decoder_layer_input: RematLocation = Field(
      RematLocation.DEVICE, description="Remat policy for the decoder layer's input."
  )
  context: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the attention context.")
  mlpwi: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the first MLP layer's intermediate output.",
  )
  mlpwi_0: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the first part of a gated MLP's output.",
  )
  mlpwi_1: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the second part of a gated MLP's output.",
  )
  mlpwo: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the second MLP layer's output.",
  )
  query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.")
  key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.")
  value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.")
  qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.")
  out_proj: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the attention output projection.",
  )
  mla_q: RematLocation = Field(
      RematLocation.REMAT,
      # Fixed typo: "projectiont" -> "projection".
      description="Remat policy for the mla's query projection.",
  )
  mla_kv: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the mla's key and value projection.",
  )
  attention_out: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the attention output.",
  )
  engram: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the engram output.")
  # Host offloading of large training state to save device HBM.
  optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
  parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")
class Tokenizer(BaseModel):
  """Configuration for the tokenizer."""

  vocab_size: int = Field(default=32_000, description="The size of the vocabulary.")
  tokenizer_path: None | PathStr = Field(default=None, description="Path to the tokenizer model file.")
  tokenizer_type: TokenizerType = Field(default=TokenizerType.SENTENCEPIECE, description="The type of tokenizer.")
  # Chat-template handling (HF-style Jinja2 templates).
  use_chat_template: bool = Field(default=False, description="Whether to use the chat template for tokenization.")
  chat_template_path: str = Field(default="", description="Path to chat template json file.")
  chat_template: str = Field(
      default="",
      description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
  )
  # Set these to False when the dataset is already tokenized.
  tokenize_train_data: bool = Field(default=True, description="If False, assumes the training dataset is pre-tokenized.")
  tokenize_eval_data: bool = Field(default=True, description="If False, assumes the evaluation dataset is pre-tokenized.")
  add_bos: bool = Field(default=True, description="Whether to add a beginning-of-sentence token.")
  add_eos: bool = Field(default=True, description="Whether to add an end-of-sentence token.")
  use_truncation: bool = Field(
      default=True,
      description="If False, use chunking for long sequences instead of truncation.",
  )
  num_vocab_tiling: int = Field(
      default=1,
      description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
  )
class DatasetGeneral(BaseModel):
"""General configuration for dataset and data loading."""
dataset_type: DatasetType = Field(DatasetType.TFDS, description="The type of the data loading pipeline.")
per_device_batch_size: int | float = Field(12, description="The batch size per device.")
eval_per_device_batch_size: int | float = Field(
0.0,
description="The batch size per device for evaluation. Defaults to per_device_batch_size.",
)
max_corpus_chars: int = Field(10_000_000, description="Maximum number of characters to use from the corpus.")
train_data_columns: list[str] = Field(["text"], description="Column(s) to use from the training data.")
train_image_column: str | list[str] = Field("image", description="Column name(s) for images in the training data.")
eval_data_columns: list[str] = Field(["text"], description="Column(s) to use from the evaluation data.")
eval_image_column: str | list[str] = Field("image", description="Column name(s) for images in evaluation data.")
packing: bool = Field(
True,
description="Whether to pack multiple short examples into a single sequence.",
)
grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
"first_fit",
description="Packing type when using Grain pipeline. 'first_fit', 'best_fit' or 'concat_then_split'.",
)
max_segments_per_seq: int = Field(
-1,
description="Maximum number of segments that can be packed into a single sequence. -1 or None for no limit.",
)
num_epoch: int = Field(1, description="Number of epochs to train for.")
expansion_factor_real_data: float = Field(-1.0, description="Factor for partial data loading on hosts.")
reuse_example_batch: int = Field(0, description="For performance testing, repeatedly uses the same batch.")
generate_padding_batch_train: bool = Field(
False,
description="Whether to generate a padding batch for training to ensure divisibility.",