MaxText is an open-source, high-performance LLM framework written in Python/JAX.

**Key capabilities and features**:

- **Supported Precisions**: FP32, BF16, INT8, and FP8.
- **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
- **Quantization**: Via **Qwix** (recommended) and AQT. See the [Quantization Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
- **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
- **Multi-Token Prediction (MTP)**: Enables token-efficient training by predicting multiple future tokens.
- **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
- **Flexible Remat Policy**: Provides fine-grained control over memory-compute trade-offs. Users can select pre-defined policies (like 'full' or 'minimal') or set the policy to **'custom'**; a pure-JAX sketch of the underlying trade-off follows this list.
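
To make the remat trade-off concrete, here is a minimal sketch of what a rematerialization policy controls, using only public `jax.checkpoint` APIs. The `mlp_block` function and the mapping to the 'full'/'custom' policy names are illustrative assumptions, not MaxText's actual layer code or config surface.

```python
import jax
import jax.numpy as jnp

def mlp_block(x, w1, w2):
  h = jnp.dot(x, w1)   # an activation the policy may save or recompute
  h = jax.nn.relu(h)
  return jnp.dot(h, w2)

# 'full'-style remat: save nothing, recompute every activation in the
# backward pass (minimum memory, maximum recompute). Shown for contrast.
full_remat = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.nothing_saveable)

# A 'custom'-style policy: keep matmul outputs, recompute cheap
# elementwise ops such as the relu.
custom_remat = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.checkpoint_dots)

def loss(x, w1, w2):
  return jnp.sum(custom_remat(x, w1, w2) ** 2)

x = jnp.ones((8, 128))
w1 = jnp.ones((128, 512))
w2 = jnp.ones((512, 128))
# The backward pass recomputes whatever the policy chose not to save.
grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)
```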
## Supported model families
> _**Note on GPU Coverage**: Support and tested configurations for NVIDIA GPUs can vary by model family. Please see the specific model guides for details._

**Primary Platforms**: All model families listed below target **TPU** and **NVIDIA GPUs**.

## Parallelism strategies

MaxText supports a wide range of parallelism strategies for scaling training and inference across TPUs and GPUs:
- **FSDP (Fully Sharded Data Parallel)**: Reduces memory footprint by sharding parameters and optimizer states across devices.
- **TP (Tensor Parallelism)**: Splits tensor computations (e.g., matrix multiplications, attention heads) across devices for intra-layer speedups.
- **EP (Expert Parallelism)**: Distributes MoE experts across devices, supporting dropless routing and load balancing to ensure efficient utilization.
- **DP (Data Parallelism)**: Replicates the model across devices while splitting the input data batches.
- **PP (Pipeline Parallelism)**: Splits layers across device stages to support extremely large models by managing inter-stage communication.
- **CP (Context Parallelism)**: Splits sequence tokens across devices, complementing tensor parallelism for long-context workloads.
- **Hybrid Parallelism**: Allows for flexible combinations of FSDP, TP, EP, DP, PP, and CP to maximize hardware utilization based on model size and topology; see the mesh sketch after this list.
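
As a rough illustration of how these strategies compose, the sketch below builds a JAX device mesh with one named axis per strategy and shards a batch and a weight across it. This is plain JAX, not MaxText's configuration surface: the axis names, the 2×2×2 mesh shape (it assumes 8 local devices), and the array shapes are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical hybrid layout: DP x FSDP x TP across 8 devices.
devices = np.array(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# DP/FSDP: shard the batch dimension over both axes.
# TP: shard the weight's output dimension over "tensor";
# FSDP: shard its input dimension over "fsdp".
batch_sharding = NamedSharding(mesh, P(("data", "fsdp"), None))
weight_sharding = NamedSharding(mesh, P("fsdp", "tensor"))

x = jax.device_put(jnp.ones((32, 1024)), batch_sharding)
w = jax.device_put(jnp.ones((1024, 4096)), weight_sharding)

# Under jit, XLA inserts the collectives (all-gathers, reduce-scatters)
# implied by the operand shardings.
y = jax.jit(jnp.dot)(x, w)
```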
## Performance characteristics
The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.
- **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware, and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it; a back-of-the-envelope calculation follows this list.
- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
- **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
- **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
- **Long-Context Optimizations**: Implements various efficient attention mechanisms, including Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, and Multi-Head Latent Attention (MLA). These reduce the KV-cache size, making it possible to handle long contexts efficiently.
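
For intuition only, here is a back-of-the-envelope MFU calculation using the common 6 × N × T approximation for dense-transformer training FLOPs. Every number below is made up for illustration; the authoritative definition is on the Performance Metrics page linked above.

```python
# Hypothetical run: a 70B-parameter dense model on 256 chips.
params = 70e9                  # N: model parameter count
tokens_per_step = 2e6          # T: global batch size in tokens
step_time_s = 12.0             # measured wall-clock time per training step
peak_flops_per_chip = 459e12   # advertised BF16 peak of a hypothetical chip
num_chips = 256

model_flops_per_step = 6 * params * tokens_per_step       # fwd + bwd approximation
achieved_flops = model_flops_per_step / step_time_s       # FLOP/s actually delivered
mfu = achieved_flops / (peak_flops_per_chip * num_chips)  # Model FLOPs Utilization

print(f"MFU = {mfu:.1%}")  # ~59.6% with these made-up numbers
```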