
Commit 04c3e3d

Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-ckpt-opt2
2 parents: d469717 + 8309e22

112 files changed

Lines changed: 214 additions & 217 deletions

Some content is hidden: large commits hide part of their contents by default, so not all 112 changed files appear below.

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ repos:
 - '--disable=R0401,R0917,W0201,W0613'
 - "--ignore-patterns='.pytype,.*pyi$'"
 - 'benchmarks'
-- 'end_to_end'
 - 'src'
 - 'tests'
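The surviving `--ignore-patterns='.pytype,.*pyi$'` entry is a comma-separated list of regular expressions that pylint matches against base names when deciding what to skip. A minimal Python sketch of that regex semantics (an illustration with hypothetical file names, not pylint's own code):

```python
# Illustration of regex-based ignore patterns, in the style of pylint's
# --ignore-patterns option. File names below are hypothetical examples.
import re

IGNORE_PATTERNS = [".pytype", r".*pyi$"]

def is_ignored(base_name: str) -> bool:
    """True if any pattern matches at the start of the base name."""
    return any(re.match(pattern, base_name) for pattern in IGNORE_PATTERNS)

assert is_ignored("layers.pyi")    # ".*pyi$" matches
assert not is_ignored("train.py")  # no pattern matches
```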

docs/reference/models/supported_models_and_architectures.md

Lines changed: 44 additions & 46 deletions
@@ -8,89 +8,87 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX.
 
 **Key capabilities and features**:
 
-* **Supported Precisions**: FP32, BF16, INT8, and FP8.
-* **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
-* **Quantization**: Via **Qwix** (recommended) and AQT. See Quantization [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
-* **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
-* **Multi-Token Prediction (MTP)**: Enables token-efficient training with multi-token prediction.
-* **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
-* **Flexible Remat Policy**: Provides fine-grained control over memory-compute trade-offs. Users can select pre-defined policies (like 'full' or 'minimal') or set the policy to **'custom'**.
-
+- **Supported Precisions**: FP32, BF16, INT8, and FP8.
+- **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
+- **Quantization**: Via **Qwix** (recommended) and AQT. See Quantization [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
+- **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
+- **Multi-Token Prediction (MTP)**: Enables token-efficient training with multi-token prediction.
+- **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
+- **Flexible Remat Policy**: Provides fine-grained control over memory-compute trade-offs. Users can select pre-defined policies (like 'full' or 'minimal') or set the policy to **'custom'**.
 
 ## Supported model families
 
 > _**Note on GPU Coverage**: Support and tested configurations for NVIDIA GPUs can vary by model family. Please see the specific model guides for details._
 
 **Primary Platforms**: All model families listed below target **TPU** and **NVIDIA GPUs**.
 
-
 ### Llama
 
-* **Variants**: Llama 2; **Llama 3 / 3.1 / 3.3**; Llama 4 (**Scout**, **Maverick**; text & multimodal)
-* **Notes**: RoPE, RMSNorm, SwiGLU; GQA; routed experts (Llama 4); **QK-Norm** (Llama 4); multimodal projector & vision encoder.
+- **Variants**: Llama 2; **Llama 3 / 3.1 / 3.3**; Llama 4 (**Scout**, **Maverick**; text & multimodal)
+- **Notes**: RoPE, RMSNorm, SwiGLU; GQA; routed experts (Llama 4); **QK-Norm** (Llama 4); multimodal projector & vision encoder.
 
 ### Mistral / Mixtral
 
-* **Variants**: Mistral (dense); Mixtral 8×7B, 8×22B (MoE)
-* **Notes**: Sliding-Window Attention (SWA), GQA; MoE top-k with load-balancing loss.
+- **Variants**: Mistral (dense); Mixtral 8×7B, 8×22B (MoE)
+- **Notes**: Sliding-Window Attention (SWA), GQA; MoE top-k with load-balancing loss.
 
 ### Gemma
 
-* **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), **Gemma 3 (4B/12B/27B)** (text & multimodal)
-* **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3); Local–Global interleaved attention; long-context scaling.
+- **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), **Gemma 3 (4B/12B/27B)** (text & multimodal)
+- **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3); Local–Global interleaved attention; long-context scaling.
 
 ### DeepSeek
 
-* **Variants**: V2 (16B, 236B), V3 (671B), R1
-* **Notes**: MLA; shared/finer-grained experts; MTP; YaRN-style scaling.
+- **Variants**: V2 (16B, 236B), V3 (671B), R1
+- **Notes**: MLA; shared/finer-grained experts; MTP; YaRN-style scaling.
 
 ### Qwen3
 
-* **Variants**: Dense (0.6B–32B); MoE (30B-A3B, 235B-A22B, 480B Coder)
-* **Notes**: **QK-Norm**, GQA, SwiGLU, RMSNorm, RoPE.
+- **Variants**: Dense (0.6B–32B); MoE (30B-A3B, 235B-A22B, 480B Coder)
+- **Notes**: **QK-Norm**, GQA, SwiGLU, RMSNorm, RoPE.
 
 ### GPT-OSS
 
-* **Variants**: 20B, 120B
-* **Notes**: Local–Global interleaved attention, GQA, attention sink; YaRN-style scaling; MoE.
+- **Variants**: 20B, 120B
+- **Notes**: Local–Global interleaved attention, GQA, attention sink; YaRN-style scaling; MoE.
 
 ## Parallelism building blocks
 
 MaxText supports a wide range of parallelism strategies for scaling training and inference across TPUs and GPUs:
 
-* **FSDP (Fully Sharded Data Parallel)**: Reduces memory footprint by sharding parameters and optimizer states across devices.
-* **TP (Tensor Parallelism)**: Splits tensor computations (e.g., matrix multiplications, attention heads) across devices for intra-layer speedups.
-* **EP (Expert Parallelism)**: Distributes MoE experts across devices, supporting dropless routing and load balancing to ensure efficient utilization.
-* **DP (Data Parallelism)**: Replicates the model across devices while splitting the input data batches.
-* **PP (Pipeline Parallelism)**: Splits layers across device stages to support extremely large models by managing inter-stage communication.
-* **CP (Context Parallelism)**: Splits sequence tokens across devices, complementing tensor parallelism for long-context workloads.
-* **Hybrid Parallelism**: Allows for flexible combinations of FSDP, TP, EP, DP, PP, and CP to maximize hardware utilization based on model size and topology.
+- **FSDP (Fully Sharded Data Parallel)**: Reduces memory footprint by sharding parameters and optimizer states across devices.
+- **TP (Tensor Parallelism)**: Splits tensor computations (e.g., matrix multiplications, attention heads) across devices for intra-layer speedups.
+- **EP (Expert Parallelism)**: Distributes MoE experts across devices, supporting dropless routing and load balancing to ensure efficient utilization.
+- **DP (Data Parallelism)**: Replicates the model across devices while splitting the input data batches.
+- **PP (Pipeline Parallelism)**: Splits layers across device stages to support extremely large models by managing inter-stage communication.
+- **CP (Context Parallelism)**: Splits sequence tokens across devices, complementing tensor parallelism for long-context workloads.
+- **Hybrid Parallelism**: Allows for flexible combinations of FSDP, TP, EP, DP, PP, and CP to maximize hardware utilization based on model size and topology.
 
 ## Performance characteristics
 
 The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.
 
-* **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
-* **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
-* **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
-* **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
-* **Long-Context Optimizations**: Implements various efficient attention mechanisms, including Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, and Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently.
-
+- **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
+- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
+- **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
+- **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
+- **Long-Context Optimizations**: Implements various efficient attention mechanisms, including Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, and Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently.
 
 ## References
 
-* [MaxText Repo](https://github.com/AI-Hypercomputer/maxtext)
+- [MaxText Repo](https://github.com/AI-Hypercomputer/maxtext)
+
+- **Model Implementation Guides & Source Code:**
 
-* **Model Implementation Guides & Source Code:**
-  * **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama4.py)
-  * **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma3.py)
-  * **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mistral.py)
-  * **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/deepseek.py)
-  * **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/qwen3.py)
-  * **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gpt_oss.py)
+  - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama4.py)
+  - **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma3.py)
+  - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mistral.py)
+  - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/deepseek.py)
+  - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/qwen3.py)
+  - **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gpt_oss.py)
 
+- **Technical Explanations:**
 
-* **Technical Explanations:**
-  * [Parallelism & Sharding](../../guides/optimization/sharding.md)
-  * [Quantization Documentation](../core_concepts/quantization.md)
-  * [AOT Compilation Instructions](../../guides/monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot)
+  - [Parallelism & Sharding](../../guides/optimization/sharding.md)
+  - [Quantization Documentation](../core_concepts/quantization.md)
+  - [AOT Compilation Instructions](../../guides/monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot)
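The parallelism building blocks listed in the document above all reduce to placing arrays on a device mesh. A minimal JAX sketch of the hybrid idea, combining an FSDP-style axis with a tensor-parallel axis (generic JAX only; the axis names and shapes are illustrative assumptions, not MaxText's configuration):

```python
# Hybrid-parallelism sketch: shard one weight matrix along two mesh axes.
# Assumes the process sees a device count divisible by 2 (e.g. a TPU slice).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 2)  # e.g. 8 devices -> 4 x 2 mesh
mesh = Mesh(devices, axis_names=("fsdp", "tensor"))

# Rows are sharded FSDP-style; columns are split tensor-parallel.
weights = jax.device_put(
    np.zeros((4096, 4096), dtype=np.float32),
    NamedSharding(mesh, P("fsdp", "tensor")),
)
print(weights.sharding)  # inspect the layout across the 2-D mesh
```

MaxText drives equivalent layout choices through its configuration rather than hand-placed `device_put` calls, but the mesh-and-PartitionSpec mechanism underneath is the same.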

docs/tutorials/posttraining/full_finetuning.md

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ python3 -m MaxText.train \
   steps=10 per_device_batch_size=1
 ```
 
-You can find some [end to end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu).
+You can find some [end to end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu).
 These scripts can serve as a reference point for common workflows.
 
 ## Parameters to achieve high MFU

src/MaxText/experimental/rl/README.md

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ This directory contains code and documentation for **GRPO**, a reinforcement learning
 
 ## Running GRPO
 
-This repository includes a shell script, `end_to_end/tpu/test_grpo.sh`, that demonstrates how to run GRPO on a v5p-256 cluster.
+This repository includes a shell script, `tests/end_to_end/tpu/test_grpo.sh`, that demonstrates how to run GRPO on a v5p-256 cluster.
 
 **How it works:**
 
@@ -50,4 +50,4 @@ DEVICES_PER_SAMPLER=8 \
 TRAINING_PER_DEVICE_BATCH_SIZE=1 \
 INFERENCE_PER_DEVICE_BATCH_SIZE=8 \
 STEPS=20 \
-bash end_to_end/tpu/test_grpo.sh
+bash tests/end_to_end/tpu/test_grpo.sh

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@
     print_ram_usage,
     print_peak_memory,
 )
-from maxtext.inference_utils import str2bool
+from maxtext.inference.inference_utils import str2bool
 
 
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
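The same import move, `maxtext.inference_utils` to `maxtext.inference.inference_utils`, repeats across the checkpoint scripts below. The diff never shows the helper's body, but a `str2bool` flag parser is conventionally shaped like the following sketch (an assumption for illustration, not the MaxText source):

```python
# Hypothetical sketch of a str2bool helper as commonly written for CLI flags;
# the real maxtext.inference.inference_utils.str2bool may differ.
def str2bool(value: str) -> bool:
    """Parse a human-friendly truthy/falsy string into a bool."""
    truthy = {"true", "t", "yes", "y", "1"}
    falsy = {"false", "f", "no", "n", "0"}
    lowered = value.strip().lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise ValueError(f"Cannot interpret {value!r} as a boolean")
```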

src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@
 
 from MaxText import max_logging
 from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
-from maxtext.inference_utils import str2bool
+from maxtext.inference.inference_utils import str2bool
 
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
 
src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 from MaxText.utils.ckpt_scripts import convert_deepseek_family_ckpt as ds_ckpt
 from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
 from MaxText import max_logging
-from maxtext.inference_utils import str2bool
+from maxtext.inference.inference_utils import str2bool
 from safetensors import safe_open
 
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log

src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt import MODEL_PARAMS_DICT, _hf_to_maxtext_mapping, _pt_to_np
 from MaxText.utils.ckpt_conversion.utils.utils import MemoryMonitorTqdm, print_peak_memory
-from maxtext.inference_utils import str2bool
+from maxtext.inference.inference_utils import str2bool
 
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
 
src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
from MaxText import max_logging
4040
from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
41-
from maxtext.inference_utils import str2bool
41+
from maxtext.inference.inference_utils import str2bool
4242

4343
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log
4444

src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
 
 from MaxText import max_logging
 from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
-from maxtext.inference_utils import str2bool
+from maxtext.inference.inference_utils import str2bool
 
 # Static model parameters dictionary
 MODEL_PARAMS_DICT = {