-
Notifications
You must be signed in to change notification settings - Fork 493
[DRAFT] Paged Stashing for ring-of-experts MoE Activation Memory #3493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -629,6 +629,24 @@ class MoEGeneral(BaseModel): | |||
| False, | ||||
| description="Whether to use Ring of Experts for sparse matmul expert parallelism.", | ||||
| ) | ||||
| ring_paged_stash: bool = Field( | ||||
| False, | ||||
| description=( | ||||
| "Enable paged stashing for ring-of-experts MoE layers. " | ||||
| "Instead of checkpointing GMM activations at worst-case buffer size " | ||||
| "(batch*EP*seq*top_k), compactly stores only the actual routed tokens " | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for the formula, could you align it with maxtext/src/maxtext/layers/moe.py Line 1170 in 16b6848
|
||||
| "in a shared static buffer, reducing host-offload memory ~4x for EP=4. " | ||||
| "See layers/paged_stash.py for details." | ||||
| ), | ||||
| ) | ||||
| ring_paged_stash_safety_margin: float = Field( | ||||
| 1.5, | ||||
| description=( | ||||
| "Safety margin multiplier on the expected-per-layer token count used to " | ||||
| "size each layer's stash chunk (max_chunk = expected * margin). " | ||||
| "1.0 = no slack; 1.5 = tolerate 50%% per-layer imbalance without dropping." | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this strategy introduce dropping? If I understand correctly, no matter how we set margin value, it will iterate till the last chunk, right? If so, shall we avoid dropping word here? |
||||
| ), | ||||
| ) | ||||
| use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") | ||||
| interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") | ||||
| expert_shard_attention_option: Literal["fsdp", "context"] = Field( | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |||
| from flax import nnx | ||||
| import jax | ||||
| from jax import ad_checkpoint as adc | ||||
| from MaxText.layers import paged_stash as ps | ||||
| from jax.experimental import xla_metadata | ||||
| from jax.sharding import NamedSharding, Mesh | ||||
| from jax.sharding import PartitionSpec as P | ||||
|
|
@@ -1305,7 +1306,40 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): | |||
| ) | ||||
| if self.config.mlp_bias: | ||||
| intermediate_output = intermediate_output + wo_bias | ||||
| intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo") | ||||
| if self.config.use_ring_of_experts and self.config.ring_paged_stash: | ||||
| # ------------------------------------------------------------------- | ||||
| # Paged stashing: instead of checkpointing intermediate_output at | ||||
| # worst-case size (batch*EP*seq*top_k, hidden), store only the actual | ||||
| # routed tokens in a compact shared buffer passed through the scan | ||||
| # carry. This reduces host-offload memory ~4x for EP=4. | ||||
| # | ||||
| # stash_buf and write_ptr are threaded through the decoder scan carry. | ||||
| # TODO: wire (stash_buf, write_ptr, layer_sizes) into the decoder | ||||
| # __call__ signatures in decoders.py / deepseek.py. For now this | ||||
| # block shows the moe.py side of the integration. | ||||
| # | ||||
| # See layers/paged_stash.py for full documentation. | ||||
| # ------------------------------------------------------------------- | ||||
| actual_tokens = jnp.sum(group_sizes) | ||||
| expected = ps.expected_tokens_per_layer( | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We found this math will give smaller range: maxtext/src/maxtext/layers/moe.py Line 1170 in 16b6848
|
||||
| batch_size, num_expert_parallelism, sequence_length, | ||||
| self.config.num_experts_per_tok, | ||||
| ) | ||||
| max_chunk = int(expected * self.config.ring_paged_stash_safety_margin) | ||||
| stash_fn, restore_fn = ps.make_stash_fns(max_chunk, self.config.emb_dim) | ||||
|
|
||||
| # Stash: compact intermediate_output into the shared buffer. | ||||
| # stash_buf / write_ptr come from the scan carry (see TODO above). | ||||
| stash_buf, write_ptr = stash_fn( | ||||
| stash_buf, write_ptr, intermediate_output, actual_tokens | ||||
| ) | ||||
| # intermediate_output is no longer needed; the backward will restore it. | ||||
| intermediate_output = restore_fn( | ||||
| stash_buf, write_ptr - actual_tokens, actual_tokens, | ||||
| intermediate_output.shape[0], | ||||
| ) | ||||
| else: | ||||
| intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo") | ||||
|
|
||||
| if self.config.use_ring_of_experts: | ||||
| # Set the outputs of tokens which were not processed to 0. | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,222 @@ | ||||
| # Copyright 2025 Google LLC | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: 2026 |
||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|
|
||||
| """Paged stashing for ring-of-experts MoE activation memory reduction. | ||||
|
|
||||
| Background | ||||
| ---------- | ||||
| With ring_of_experts=True and expert_parallelism=EP, the token buffer fed into | ||||
| each MoE GMM is inflated to: | ||||
|
|
||||
| worst_case = batch * EP * seq * top_k (e.g. 262,144 for bs=2, EP=4, | ||||
| seq=4096, top_k=8) | ||||
|
|
||||
| However, with load-balance loss the *actual* tokens routed to each EP shard is | ||||
| roughly: | ||||
|
|
||||
| expected = worst_case / EP (e.g. 65,536) | ||||
|
|
||||
| Naively checkpointing the GMM outputs (moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo) at | ||||
| worst-case size balloons host-offload memory to ~210 GB for a 60-layer model. | ||||
|
|
||||
| Idea (inspired by Megatron-LM PR #2690 "Paged Stashing") | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we include a link? |
||||
| --------------------------------------------------------- | ||||
| Decouple the *compute* buffer (worst_case tokens, needed by XLA/CUDA graphs) | ||||
| from the *storage* buffer (actual tokens, ~4x smaller with EP=4). | ||||
|
|
||||
| Forward: | ||||
| 1. Run GMM on full worst_case buffer → output (worst_case, hidden) | ||||
| 2. Stash: copy only the actual tokens into a compact shared buffer at a | ||||
| dynamically-tracked offset. The rest of the worst_case buffer is freed. | ||||
|
|
||||
| Backward: | ||||
| 1. Restore: scatter the compact tokens back into a zero-padded worst_case | ||||
| buffer at the correct positions. | ||||
| 2. Run GMM backward on the restored buffer. | ||||
|
|
||||
| JAX implementation | ||||
| ------------------ | ||||
| XLA requires static tensor shapes but supports *dynamic start indices* in | ||||
| lax.dynamic_update_slice / lax.dynamic_slice. We exploit this as follows: | ||||
|
|
||||
| - The shared stash buffer has a STATIC shape: | ||||
| (TOTAL_CAPACITY, hidden) | ||||
| where TOTAL_CAPACITY = num_layers * expected_per_layer + MAX_PER_LAYER. | ||||
|
|
||||
| - Each layer writes a fixed-size chunk (MAX_PER_LAYER rows) at a *dynamic* | ||||
| offset tracked in the scan carry, then advances the offset by the layer's | ||||
| actual token count (a dynamic scalar). Subsequent layers thus pack their | ||||
| tokens immediately after the previous layer's actual data, with no gaps. | ||||
|
|
||||
| - On the backward scan the process is reversed using the stored per-layer | ||||
| offsets and sizes. | ||||
|
|
||||
| Memory comparison (bs=2, EP=4, seq=4096, top_k=8, hidden=7168, 60 MoE layers): | ||||
|
|
||||
| Strategy | tokens/layer | Host memory (moe_mlpwo only) | ||||
| --------------------------|---------------|----------------------------- | ||||
| No cap (baseline) | 262,144 | ~210 GB ❌ | ||||
| 50% static cap | 131,072 | ~105 GB ✅ | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We didn't see the trend of dropping strategy due to model quality. Shall we remove this option if Paged stash is obvious better here? |
||||
| Paged stash (this PR) | ~65,536 avg | ~52 GB ✅ (2x better) | ||||
|
|
||||
| The paged stash approach is strictly superior to a static cap because: | ||||
| - No token dropping: GMM still runs on all worst_case tokens. | ||||
| - Transient per-layer imbalance is absorbed by the shared budget instead of | ||||
| dropping tokens in that layer. | ||||
| - Memory tracks *actual* load rather than a fixed conservative ceiling. | ||||
|
|
||||
| Scan carry integration | ||||
| ---------------------- | ||||
| The stash buffer and write pointer are threaded through the decoder scan carry: | ||||
|
|
||||
| carry = (residual, stash_buf, write_ptr, layer_sizes) | ||||
|
|
||||
| `layer_sizes` is a (num_layers,) integer array recording each layer's actual | ||||
| token count; it is needed by the backward scan to compute read offsets. | ||||
|
|
||||
| TODO: The decoder __call__ signatures in decoders.py / deepseek.py need to be | ||||
| updated to thread (stash_buf, write_ptr, layer_sizes) through the scan carry. | ||||
| This PR implements the core primitives and MoE-layer integration; the decoder | ||||
| wiring is left as a follow-up. | ||||
| """ | ||||
|
|
||||
| import functools | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| from jax import lax | ||||
|
|
||||
|
|
||||
| # --------------------------------------------------------------------------- | ||||
| # Core stash / restore primitives | ||||
| # --------------------------------------------------------------------------- | ||||
|
|
||||
def make_stash_fns(max_chunk: int, hidden: int):
  """Return (stash_fn, restore_fn) for a given static chunk size and hidden dim.

  Args:
    max_chunk: Static maximum number of tokens written per layer. Must be >=
      the maximum actual token count that will ever be encountered.
      Set to expected_per_layer * safety_margin (e.g. 1.5x).
    hidden: Hidden dimension of the tensors to stash.

  Returns:
    stash_fn: (buf, write_ptr, x, actual_tokens) -> (new_buf, new_write_ptr)
    restore_fn: (buf, read_ptr, actual_tokens, full_size) -> x_full

  Assumptions (to be confirmed against the moe.py caller): x passed to
  stash_fn has shape (full_size, hidden) with full_size >= max_chunk, buf has
  the same dtype as x, and actual_tokens <= max_chunk at all times.
  """

  @jax.custom_vjp
  def _stash_chunk(buf, write_ptr, chunk, actual_tokens):
    """Write the first `actual_tokens` rows of `chunk` into buf at write_ptr.

    Rows beyond actual_tokens are written as zeros (masked), so buf remains
    well-defined for any downstream reads. The write pointer advances by
    actual_tokens (not max_chunk), so successive layers pack tightly.
    """
    mask = jnp.arange(max_chunk) < actual_tokens  # (max_chunk,)
    masked_chunk = jnp.where(mask[:, None], chunk, 0.0)  # (max_chunk, hidden)
    new_buf = lax.dynamic_update_slice(buf, masked_chunk, (write_ptr, 0))
    return new_buf, write_ptr + actual_tokens

  def _stash_chunk_fwd(buf, write_ptr, chunk, actual_tokens):
    outs = _stash_chunk(buf, write_ptr, chunk, actual_tokens)
    # Only two integer scalars are saved for the backward pass. Crucially,
    # neither `chunk` nor `buf` is kept as a residual — that is the entire
    # point of stashing (the backward reads the data back out of new_buf).
    return outs, (write_ptr, actual_tokens)

  def _stash_chunk_bwd(res, g):
    write_ptr, actual_tokens = res
    # The second output (new_write_ptr) is integer-valued, so its incoming
    # cotangent carries no information; only d_new_buf matters.
    d_new_buf, _ = g
    # Gradient w.r.t. chunk: read back the slice that was written.
    d_chunk = lax.dynamic_slice(d_new_buf, (write_ptr, 0), (max_chunk, hidden))
    mask = jnp.arange(max_chunk) < actual_tokens
    # Rows >= actual_tokens were masked to 0 in the forward: no gradient.
    d_chunk = jnp.where(mask[:, None], d_chunk, 0.0)
    # Gradient w.r.t. buf: the written region was overwritten in the forward
    # (its old contents did not reach the output), so zero that region out.
    zeros = jnp.zeros((max_chunk, hidden), dtype=d_new_buf.dtype)
    d_buf = lax.dynamic_update_slice(d_new_buf, zeros, (write_ptr, 0))
    # None == zero cotangent for the integer (non-differentiable) arguments.
    return d_buf, None, d_chunk, None

  _stash_chunk.defvjp(_stash_chunk_fwd, _stash_chunk_bwd)

  def stash_fn(buf, write_ptr, x, actual_tokens):
    """Stash the first actual_tokens rows of x; returns (new_buf, new_write_ptr).

    The slice x[:max_chunk] is taken *outside* the custom_vjp so that JAX's
    native slice VJP zero-pads the cotangent back to x's full
    (full_size, hidden) shape. (The draft version produced a d_x of shape
    (max_chunk, hidden), which is wrong whenever full_size > max_chunk.)
    """
    return _stash_chunk(buf, write_ptr, x[:max_chunk], actual_tokens)

  def restore_fn(buf, read_ptr, actual_tokens, full_size):
    """Read actual_tokens rows from buf[read_ptr:] into a (full_size, hidden) array.

    Only the first actual_tokens rows of the result are non-zero, matching the
    layout produced by the forward permute step. full_size must be a static
    Python int (it sizes the output shape).

    No custom_vjp is needed here: every op used (dynamic_slice, where, gather)
    has a native VJP whose residuals are small scalars / index vectors, so
    autodiff is both correct and memory-compact. (The draft's hand-written
    bwd referenced the non-existent `lax.dynamic_slice.out_aval` and built a
    cotangent of the wrong shape; native autodiff also correctly routes the
    gradient back into buf at read_ptr.)
    """
    compact = lax.dynamic_slice(buf, (read_ptr, 0), (max_chunk, hidden))
    indices = jnp.arange(full_size)
    mask = indices < actual_tokens
    # Clamp so the gather stays in bounds; masked rows are zeroed regardless.
    safe_idx = jnp.minimum(indices, max_chunk - 1)
    return jnp.where(mask[:, None], compact[safe_idx], 0.0)

  return stash_fn, restore_fn
|
|
||||
|
|
||||
| # --------------------------------------------------------------------------- | ||||
| # Buffer sizing helpers | ||||
| # --------------------------------------------------------------------------- | ||||
|
|
||||
def stash_buffer_size(num_moe_layers: int, expected_per_layer: int, max_chunk: int) -> int:
  """Total number of rows in the shared stash buffer.

  Budget: the expected cumulative token count across all MoE layers, plus one
  extra max_chunk of headroom so the final layer can exceed its expected share
  without running off the end of the buffer.
  """
  headroom = max_chunk
  return num_moe_layers * expected_per_layer + headroom
|
|
||||
|
|
||||
def expected_tokens_per_layer(
    batch_size: int,
    num_expert_parallelism: int,
    sequence_length: int,
    num_experts_per_tok: int,
) -> int:
  """Expected token count per EP shard per layer under uniform routing.

  The ring-of-experts compute buffer is inflated to
  batch * EP * seq * top_k rows (worst case); with perfectly balanced routing
  each EP shard actually receives only 1/EP of that.
  """
  inflated_buffer_rows = batch_size * num_expert_parallelism * sequence_length * num_experts_per_tok
  return inflated_buffer_rows // num_expert_parallelism
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we also add those into base.yml & this doc for alignment? Similar comments for other occurrences.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we call it ring_of_experts_paged_stash to distinguish it from ring attention?