Commit 12dfd5c

Update the Ironwood offloading description
1 parent 1d29ce1 commit 12dfd5c

1 file changed

Lines changed: 90 additions & 65 deletions

File tree

docs/guides/optimization/custom_model.md

@@ -28,14 +28,14 @@ Based on resources like [Language Modeling from Scratch](https://github.com/stan
Dense models

- `mlp_dim / emb_dim`: 2.5-4
- `head_dim * num_query_heads / emb_dim`: 1-2
- `emb_dim / num_decoder_layers`: 100-200

MoE models

- sparsity (`num_experts / num_experts_per_tok`): 4-32
- `moe_mlp_dim / emb_dim`: 0.3-3
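These ranges are easy to sanity-check before settling on a config. A minimal sketch (plain Python, not a MaxText API; the helper name is hypothetical):

```python
# Hypothetical sanity check for the shape ratios above; not part of MaxText.
def check_ratios(emb_dim, mlp_dim, head_dim, num_query_heads, num_decoder_layers):
    checks = {
        "mlp_dim / emb_dim": (mlp_dim / emb_dim, 2.5, 4),
        "head_dim * num_query_heads / emb_dim":
            (head_dim * num_query_heads / emb_dim, 1, 2),
        "emb_dim / num_decoder_layers": (emb_dim / num_decoder_layers, 100, 200),
    }
    for name, (value, lo, hi) in checks.items():
        status = "ok" if lo <= value <= hi else f"outside [{lo}, {hi}]"
        print(f"{name} = {value:.2f} ({status})")

# Example: a Llama3-8B-like dense shape passes all three checks.
check_ratios(emb_dim=4096, mlp_dim=14336, head_dim=128,
             num_query_heads=32, num_decoder_layers=32)
```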
## Step 2. Consider TPU best practices

@@ -45,8 +45,8 @@ To unlock peak performance on [TPUs](https://cloud.google.com/tpu/docs/system-ar

Therefore, for optimal efficiency:

- Model and MLP Dimensions: Design your model's `emb_dim` and `mlp_dim` to be multiples of 256 (for Trillium and Ironwood) or 128 (for older TPUs).
- Self-Attention Head Dimension: Ensure your attention `head_dim` is also a multiple of 256 (for Trillium and Ironwood) or 128 (for older TPUs).

Generally, larger multiples are more efficient. If achieving these specific multiples isn't possible, round dimensions to a multiple of either 8 or 128 to help the XLA compiler optimize memory and computation.
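The rounding itself is simple arithmetic. A minimal sketch (hypothetical helper, not a MaxText utility):

```python
def round_up_to_multiple(dim: int, multiple: int = 256) -> int:
    """Round a dimension up to a multiple of 256 (Trillium/Ironwood)
    or 128 (older TPUs)."""
    return ((dim + multiple - 1) // multiple) * multiple

print(round_up_to_multiple(13000))       # 13056 = 51 * 256
print(round_up_to_multiple(13000, 128))  # 13056 = 102 * 128
```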

@@ -59,63 +59,68 @@ Ironwood is engineered for cutting-edge, large-scale AI model training and infer
We have published optimized recipes for models like DeepSeek v3, GPT-OSS, Qwen3, and Llama3 on Ironwood, covering both BF16 and FP8 precision, available in this [guide](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/ironwood).

Key strategies to maximize performance on Ironwood include:

- Adopt FP8 Precision: Ironwood delivers 2x throughput with FP8 compared to BF16. Design models to use mixed-precision training, employing FP8 for weights and activations where possible to maximize computational speed.
- Offload to SparseCores: Ironwood's enhanced SparseCores are crucial for efficiency. Offload collective communication and data management so that TensorCores stay focused on compute.
- Leverage the dual-chiplet architecture: Each Ironwood chip contains two TensorCores with an ultra-fast die-to-die interconnect (6x faster than a 1D ICI link).

- Leverage SparseCore offloading: By default, collective operations (like All-Reduce, All-Gather, etc.) are offloaded to the SparseCores, allowing them to run in parallel with TensorCore computations. This effectively hides communication latency and improves Model Flop Utilization (MFU). You can further maximize throughput by tuning these [XLA flags](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/benchmarks/xla_flags_library.py#L70-L116), as sketched after this list.
- Optimize sharding strategies: Align your model distribution with the hardware topology. Choose sharding strategies (e.g., data, tensor, pipeline parallelism) that minimize data transfer over the ICI and maximize the overlap between computation and communication.
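A common way to pass such flags in a MaxText-style launch is the `LIBTPU_INIT_ARGS` environment variable. The sketch below shows only the mechanism; the concrete flag strings should be copied from the tested sets in the `xla_flags_library.py` file linked above, not written by hand:

```python
import os

# Illustrative mechanism only: the concrete flag set is hardware- and
# version-specific, so copy it verbatim from the linked xla_flags_library.py.
sparsecore_offload_flags = [
    # e.g. the SparseCore-offload flags from xla_flags_library.py go here
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(sparsecore_offload_flags)
# Set this before JAX initializes the TPU backend so libtpu picks it up.
```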
### Performance configs
Use these general runtime configurations to improve your model's performance.

- **Multi-Head Attention (MHA)**. If you are using MHA, we recommend setting `fused_qkv=True` to fuse the query, key, and value computations into a single, more efficient operation.

- **Flash Attention**. Use the largest possible block size to maximize throughput.

- **Memory usage**. To free up memory with large models, use a custom remat policy to offload layer activations (including inputs, attention, and MLP blocks) to the host CPU.

- **Compiler flags**. XLA is the backend compiler for TPUs. Many critical performance settings can be controlled directly through XLA flags. We suggest beginning with the proven flags we have tested and provided [here](https://github.com/AI-Hypercomputer/maxtext/blob/b53bf3bef6b54b1d4939a4b700bc11fe149d1128/benchmarks/xla_flags_library.py).

- **Benchmark**. For consistent speed tests, set `reuse_example_batch=1` to repeatedly use the same data batch, isolating computation speed from data loading. Alternatively, use on-the-fly generated data by setting `dataset_type=synthetic`. A combined override example follows this list.
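As a combined sketch, the overrides below use the parameter names mentioned in this list (`fused_qkv`, `reuse_example_batch`, `dataset_type`); verify them against your MaxText config schema before relying on them:

```python
# Hypothetical benchmarking override set; check the key names against your
# MaxText config schema before use.
perf_overrides = {
    "fused_qkv": True,            # fuse Q/K/V projections under MHA
    "reuse_example_batch": 1,     # reuse one batch to isolate compute speed
    "dataset_type": "synthetic",  # or generate data on the fly
}
# Render as command-line overrides of the form key=value.
print(" ".join(f"{k}={v}" for k, v in perf_overrides.items()))
# -> fused_qkv=True reuse_example_batch=1 dataset_type=synthetic
```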
## Step 3. Choose efficient sharding strategies using Roofline Analysis
To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding) and JAX's [scaling book](https://jax-ml.github.io/scaling-book/sharding/).

| TPU Type | ICI Arithmetic Intensity |
|---|---|
| v5p | 2550 for 1D-ICI |
| Trillium | 5100 for 1D-ICI (1D with wraparound or 2D without wraparound) <br> 2550 for 2D-ICI (2D with wraparound on both dimensions), particularly for v6e-256 |
| Ironwood | 12800 for 1D-ICI |
### Fully Sharded Data Parallelism (FSDP)
#### Pure FSDP
For pure FSDP to be effective, it must have enough memory to hold both a large data batch and a full, single layer of weights at the same time.

FSDP AI: `global batch / sparsity` (`sparsity = num_experts / num_experts_per_tok`).

**Example with a sparsity of 16**:

- `global batch / sparsity > hardware AI`

v5p:

- `global batch / 16 > 2550`
- `global batch > 40k` (in tokens)

Trillium:

- `global batch / 16 > 2550` (16x16 with wraparound)
- `global batch > 40k` (in tokens)

We also need a single layer of weights to fit into memory, which can be an issue for medium/large MoE models. For example, DeepSeek has roughly 10B params per layer, corresponding to 40GiB of bf16 weights and gradients, which will not fit into Trillium's 32GiB of HBM. So pure FSDP on Trillium is feasible only for models whose layers do not exceed roughly 5B parameters; larger models need Expert or Tensor Parallelism.

Ironwood:

- `global batch / 16 > 12800`
- `global batch > 205k` (in tokens)
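The thresholds above follow mechanically from the roofline inequality `global batch > hardware AI * sparsity`. A worked computation (plain arithmetic, not MaxText code):

```python
# Pure-FSDP thresholds: global_batch > hw_ai * sparsity.
# AI values from the table above; Trillium uses 2550 (16x16 with wraparound).
hw_ai = {"v5p": 2550, "Trillium": 2550, "Ironwood": 12800}
sparsity = 16

for tpu, ai in hw_ai.items():
    print(f"{tpu}: global batch > {ai * sparsity:,} tokens")
# v5p: > 40,800; Trillium: > 40,800; Ironwood: > 204,800 (the ~40k/40k/205k above)
```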
#### Mixed FSDP
@@ -124,19 +129,23 @@ For sparse models, large models, or when scaling to a large number of chips FSDP
The same AI as derived in the Pure FSDP section above still holds: we need `global batch / sparsity * FSDP > hardware AI`, which is equivalent to `per device batch (pdb) / sparsity * TP * EP * PP > hardware AI`.

**Example with EP=16, FSDP=16, and sparsity=32**:

- `pdb * EP / sparsity > hardware AI`

v5p:

- `pdb * 16 / 32 > 2550`
- `pdb > 2550 * 32 / 16 = 5k` (in tokens)

Trillium:

- `pdb * 16 / 32 > 5100`
- `pdb > 5100 * 32 / 16 = 10k` (in tokens)

Ironwood:

- `pdb * 16 / 32 > 12800`
- `pdb > 12800 * 32 / 16 = 26k` (in tokens)

We need a per device batch of at least 5k for v5p, 10k for Trillium, and 26k for Ironwood in this case.
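The same arithmetic as a short sketch (plain Python):

```python
# Mixed-FSDP threshold: from pdb * EP / sparsity > hw_ai,
# we get pdb > hw_ai * sparsity / EP.
def min_per_device_batch(hw_ai, sparsity=32, ep=16):
    return hw_ai * sparsity / ep

for tpu, ai in {"v5p": 2550, "Trillium": 5100, "Ironwood": 12800}.items():
    print(f"{tpu}: pdb > {min_per_device_batch(ai):,.0f} tokens")
# v5p: 5,100; Trillium: 10,200; Ironwood: 25,600 (the ~5k/10k/26k above)
```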
@@ -149,16 +158,19 @@ AI of 1D EP on ICI rings `= 4 * mlp_dim / EP`. Communication cost of all-to-all
**Example with EP=4**

v5p:

- `4 * M > 2550 * 4`
- `M > 2.5k`

Trillium:

- `4 * M > 5100 * 4`
- `M > 5k`

Ironwood:

- `4 * M > 12800 * 4`
- `M > 13k`

These examples show that to use EP, we need a large enough MLP dimension.
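Equivalently, solving `4 * M / EP > hardware AI` for M gives the minimum MLP dimension (plain arithmetic sketch):

```python
# 1D EP roofline: 4 * M / EP > hw_ai, i.e. M > hw_ai * EP / 4.
def min_mlp_dim_for_ep(hw_ai, ep=4):
    return hw_ai * ep / 4

for tpu, ai in {"v5p": 2550, "Trillium": 5100, "Ironwood": 12800}.items():
    print(f"{tpu}: M > {min_mlp_dim_for_ep(ai):,.0f}")
# v5p: 2,550; Trillium: 5,100; Ironwood: 12,800 (the ~2.5k/5k/13k above)
```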
@@ -171,32 +183,39 @@ Tensor parallelism can be used for large dense models or super large sparse mode
AI of TP: `M / TP`

**Example with TP=4**

- `M / TP > hardware AI`

v5p:

- `M / 4 > 2550`
- `M > 10k`

Trillium:

- `M / 4 > 5100`
- `M > 20k`

We have seen in practice that M should be even larger - ideally 40k+. This is what we use for Llama-405B (M=53k), and what was used for a custom sparse 10T model (M=40k, 64 experts). TP=4 corresponds to a custom Trillium mesh, an 8x8 ring of 2x2 subrings (the TP communication operates on the 2x2 ring). This 2x2 ring performs well (near roofline), but the 8x8 rings perform poorly (0.5 x 1 axis). For example, if we use FSDP=64, TP=4, the FSDP=64 communications will be slower than the hardware ICI roofline, so we prefer to use the full 16 axis when M is large enough.

Ironwood:

- `M / 4 > 12800`
- `M > 51k`

**Example with TP=16**

- `M / TP > hardware AI`

v5p:

- `M / 16 > 2550`
- `M > 41k`

Trillium:

- `M / 16 > 5100`
- `M > 82k`

To use TP=16, we need M > 80k (ideally larger, 100k+). We have used this in a custom dense model (900B, M=131k), which performs very well even at 1k per device tokens (scaling to 25k+ with a reasonable global batch).
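Since the TP roofline reduces to `M > hardware AI * TP`, the thresholds can be tabulated directly (plain arithmetic; the Ironwood TP=16 row is an extrapolation from the same formula, not stated above):

```python
# TP roofline: M / TP > hw_ai, i.e. M > hw_ai * TP.
for tp in (4, 16):
    for tpu, ai in {"v5p": 2550, "Trillium": 5100, "Ironwood": 12800}.items():
        print(f"TP={tp:>2}, {tpu}: M > {ai * tp:,}")
# TP=4:  10,200 / 20,400 / 51,200 (the ~10k/20k/51k above)
# TP=16: 40,800 / 81,600 / 204,800 (the ~41k/82k above; Ironwood extrapolated)
```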
@@ -207,27 +226,33 @@ Pipeline Parallelism is advantageous when global batch size limits per device ba
AI of PP: `3/2 * layers_per_pipeline_stage * M * num_experts_per_tok`

**Example with PP=16, layers_per_pipeline_stage=1, num_experts_per_tok=8**

- `3/2 * layers_per_pipeline_stage * M * num_experts_per_tok > hardware AI`

v5p - PP over ICI:

- `3 * M * 8 / 2 > 2550`
- `M > 210`

v5p - PP over DCN:

- `3 * M * 8 / 2 > 73000`
- `M > 6k`

Trillium over ICI:

- `3 * M * 8 / 2 > 5100`
- `M > 420`

Trillium over DCN:

- `3 * M * 8 / 2 > 73000`
- `M > 6k`

Ironwood over ICI:

- `3 * M * 8 / 2 > 12800`
- `M > 1100`
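Solving the PP roofline for M gives the per-link thresholds (plain arithmetic sketch; the 73000 DCN arithmetic intensity is taken from the DCN examples above):

```python
# PP roofline: 3/2 * layers_per_stage * M * experts_per_tok > hw_ai,
# i.e. M > 2 * hw_ai / (3 * layers_per_stage * experts_per_tok).
def min_mlp_dim_for_pp(hw_ai, layers_per_stage=1, experts_per_tok=8):
    return 2 * hw_ai / (3 * layers_per_stage * experts_per_tok)

links = {"v5p ICI": 2550, "Trillium ICI": 5100, "Ironwood ICI": 12800, "DCN": 73000}
for link, ai in links.items():
    print(f"{link}: M > {min_mlp_dim_for_pp(ai):,.0f}")
# ~213, ~425, ~1,067, ~6,083 - matching the rounded 210/420/1100/6k above
```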
It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to this [link](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md#pp--fsdpdp) for specific challenges regarding PP + FSDP/DP.
@@ -260,8 +285,8 @@ To use Trillium's 16x16 mesh efficiently for a large dense model, we would like
Our objective was to develop a custom Mixtral-like MoE model capable of high MFU on Trillium TPUs, targeting a 1.5 capacity factor (the **capacity factor** is a multiplier used to determine the processing capacity of each expert: Expert Capacity = (Tokens in Batch / Number of Experts) * Capacity Factor). We established an initial baseline of 43.1% MFU with a 1.0 capacity factor. Profiling revealed this configuration utilized approximately 20GiB of HBM. To better leverage Trillium's 32GiB of HBM and avoid potential convergence issues with large global batch sizes during scaling (maintaining a per device batch size of 8k), we made the following architectural adjustments:
- Increased the MLP dimension from 3x to 4x of the model dimension (32,768 : 8,192).
- Increased query heads from 32 to 128 for each layer, while reducing the number of layers from 72 to 56 to keep the overall model size around 700B.

These changes, without updating sharding strategies, initially yielded nearly 50% MFU. Upon increasing the capacity factor to 1.5 (adding a buffer that allows experts to handle imbalance in token routing), MFU decreased slightly to 38.1%, and scaling to 4 pods gave 35.3% MFU, which still exceeded our target of 35%. More detailed configs can be found [here](https://github.com/AI-Hypercomputer/maxtext/blob/3662540ee852d0d8f8333a36c04ddc0f1316ebfb/benchmarks/maxtext_trillium_model_configs.py#L1743) in the repo.
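For concreteness, the capacity-factor formula quoted above works out as follows (illustrative token counts, not the experiment's actual values):

```python
# Expert Capacity = (Tokens in Batch / Number of Experts) * Capacity Factor.
def expert_capacity(tokens_in_batch, num_experts, capacity_factor):
    return tokens_in_batch / num_experts * capacity_factor

# Illustrative numbers only: raising the factor from 1.0 to 1.5 gives each
# expert 50% headroom to absorb routing imbalance.
print(expert_capacity(tokens_in_batch=8192, num_experts=8, capacity_factor=1.0))  # 1024.0
print(expert_capacity(tokens_in_batch=8192, num_experts=8, capacity_factor=1.5))  # 1536.0
```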