Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,24 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
ring_paged_stash: bool = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we also add those into base.yml & this doc for alignment? Similar comments for other occurrences.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we call it ring_of_experts_paged_stash to distinguish ring attetnion?

False,
description=(
"Enable paged stashing for ring-of-experts MoE layers. "
"Instead of checkpointing GMM activations at worst-case buffer size "
"(batch*EP*seq*top_k), compactly stores only the actual routed tokens "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the formula, could you align it with

buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
?

"in a shared static buffer, reducing host-offload memory ~4x for EP=4. "
"See layers/paged_stash.py for details."
),
)
ring_paged_stash_safety_margin: float = Field(
1.5,
description=(
"Safety margin multiplier on the expected-per-layer token count used to "
"size each layer's stash chunk (max_chunk = expected * margin). "
"1.0 = no slack; 1.5 = tolerate 50%% per-layer imbalance without dropping."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this strategy introduce dropping? If I understand correctly, no matter how we set margin value, it will iterate till the last chunk, right? If so, shall we avoid dropping word here?

),
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
expert_shard_attention_option: Literal["fsdp", "context"] = Field(
Expand Down
36 changes: 35 additions & 1 deletion src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flax import nnx
import jax
from jax import ad_checkpoint as adc
from MaxText.layers import paged_stash as ps
from jax.experimental import xla_metadata
from jax.sharding import NamedSharding, Mesh
from jax.sharding import PartitionSpec as P
Expand Down Expand Up @@ -1305,7 +1306,40 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
)
if self.config.mlp_bias:
intermediate_output = intermediate_output + wo_bias
intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
if self.config.use_ring_of_experts and self.config.ring_paged_stash:
# -------------------------------------------------------------------
# Paged stashing: instead of checkpointing intermediate_output at
# worst-case size (batch*EP*seq*top_k, hidden), store only the actual
# routed tokens in a compact shared buffer passed through the scan
# carry. This reduces host-offload memory ~4x for EP=4.
#
# stash_buf and write_ptr are threaded through the decoder scan carry.
# TODO: wire (stash_buf, write_ptr, layer_sizes) into the decoder
# __call__ signatures in decoders.py / deepseek.py. For now this
# block shows the moe.py side of the integration.
#
# See layers/paged_stash.py for full documentation.
# -------------------------------------------------------------------
actual_tokens = jnp.sum(group_sizes)
expected = ps.expected_tokens_per_layer(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We found this math will give smaller range:

buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
?

batch_size, num_expert_parallelism, sequence_length,
self.config.num_experts_per_tok,
)
max_chunk = int(expected * self.config.ring_paged_stash_safety_margin)
stash_fn, restore_fn = ps.make_stash_fns(max_chunk, self.config.emb_dim)

# Stash: compact intermediate_output into the shared buffer.
# stash_buf / write_ptr come from the scan carry (see TODO above).
stash_buf, write_ptr = stash_fn(
stash_buf, write_ptr, intermediate_output, actual_tokens
)
# intermediate_output is no longer needed; the backward will restore it.
intermediate_output = restore_fn(
stash_buf, write_ptr - actual_tokens, actual_tokens,
intermediate_output.shape[0],
)
else:
intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")

if self.config.use_ring_of_experts:
# Set the outputs of tokens which were not processed to 0.
Expand Down
222 changes: 222 additions & 0 deletions src/maxtext/layers/paged_stash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright 2025 Google LLC
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 2026

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Paged stashing for ring-of-experts MoE activation memory reduction.

Background
----------
With ring_of_experts=True and expert_parallelism=EP, the token buffer fed into
each MoE GMM is inflated to:

worst_case = batch * EP * seq * top_k (e.g. 262,144 for bs=2, EP=4,
seq=4096, top_k=8)

However, with load-balance loss the *actual* tokens routed to each EP shard is
roughly:

expected = worst_case / EP (e.g. 65,536)

Naively checkpointing the GMM outputs (moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo) at
worst-case size balloons host-offload memory to ~210 GB for a 60-layer model.

Idea (inspired by Megatron-LM PR #2690 "Paged Stashing")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we include a link?

---------------------------------------------------------
Decouple the *compute* buffer (worst_case tokens, needed by XLA/CUDA graphs)
from the *storage* buffer (actual tokens, ~4x smaller with EP=4).

Forward:
1. Run GMM on full worst_case buffer → output (worst_case, hidden)
2. Stash: copy only the actual tokens into a compact shared buffer at a
dynamically-tracked offset. The rest of the worst_case buffer is freed.

Backward:
1. Restore: scatter the compact tokens back into a zero-padded worst_case
buffer at the correct positions.
2. Run GMM backward on the restored buffer.

JAX implementation
------------------
XLA requires static tensor shapes but supports *dynamic start indices* in
lax.dynamic_update_slice / lax.dynamic_slice. We exploit this as follows:

- The shared stash buffer has a STATIC shape:
(TOTAL_CAPACITY, hidden)
where TOTAL_CAPACITY = num_layers * expected_per_layer + MAX_PER_LAYER.

- Each layer writes a fixed-size chunk (MAX_PER_LAYER rows) at a *dynamic*
offset tracked in the scan carry, then advances the offset by the layer's
actual token count (a dynamic scalar). Subsequent layers thus pack their
tokens immediately after the previous layer's actual data, with no gaps.

- On the backward scan the process is reversed using the stored per-layer
offsets and sizes.

Memory comparison (bs=2, EP=4, seq=4096, top_k=8, hidden=7168, 60 MoE layers):

Strategy | tokens/layer | Host memory (moe_mlpwo only)
--------------------------|---------------|-----------------------------
No cap (baseline) | 262,144 | ~210 GB ❌
50% static cap | 131,072 | ~105 GB ✅
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We didn't see the trend of dropping strategy due to model quality. Shall we remove this option if Paged stash is obvious better here?

Paged stash (this PR) | ~65,536 avg | ~52 GB ✅ (2x better)

The paged stash approach is strictly superior to a static cap because:
- No token dropping: GMM still runs on all worst_case tokens.
- Transient per-layer imbalance is absorbed by the shared budget instead of
dropping tokens in that layer.
- Memory tracks *actual* load rather than a fixed conservative ceiling.

Scan carry integration
----------------------
The stash buffer and write pointer are threaded through the decoder scan carry:

carry = (residual, stash_buf, write_ptr, layer_sizes)

`layer_sizes` is a (num_layers,) integer array recording each layer's actual
token count; it is needed by the backward scan to compute read offsets.

TODO: The decoder __call__ signatures in decoders.py / deepseek.py need to be
updated to thread (stash_buf, write_ptr, layer_sizes) through the scan carry.
This PR implements the core primitives and MoE-layer integration; the decoder
wiring is left as a follow-up.
"""

import functools
import jax
import jax.numpy as jnp
from jax import lax


# ---------------------------------------------------------------------------
# Core stash / restore primitives
# ---------------------------------------------------------------------------

def make_stash_fns(max_chunk: int, hidden: int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we write a unit test in moe_test.py to ensure the correctness with expert sharding? One example:

def test_megablox_expert_parallelism(self):

"""Return (stash_fn, restore_fn) for a given static chunk size and hidden dim.

Args:
max_chunk: Static maximum number of tokens written per layer. Must be >=
the maximum actual token count that will ever be encountered.
Set to expected_per_layer * safety_margin (e.g. 1.5x).
hidden: Hidden dimension of the tensors to stash.

Returns:
stash_fn: (buf, write_ptr, x, actual_tokens) -> (new_buf, new_write_ptr)
restore_fn: (buf, read_ptr, actual_tokens, full_size) -> (x_full,)
"""

@jax.custom_vjp
def stash_fn(buf, write_ptr, x, actual_tokens):
"""Write the first `actual_tokens` rows of x into buf at write_ptr.

Rows beyond actual_tokens are written as zeros (masked), so buf remains
well-defined for any downstream reads. The write pointer advances by
actual_tokens (not max_chunk), so successive layers pack tightly.
"""
mask = jnp.arange(max_chunk) < actual_tokens # (max_chunk,)
chunk = jnp.where(mask[:, None], x[:max_chunk], 0.0) # (max_chunk, hidden)
new_buf = lax.dynamic_update_slice(buf, chunk, (write_ptr, 0))
new_write_ptr = write_ptr + actual_tokens
return new_buf, new_write_ptr

def stash_fn_fwd(buf, write_ptr, x, actual_tokens):
new_buf, new_write_ptr = stash_fn(buf, write_ptr, x, actual_tokens)
# Save write_ptr and actual_tokens for the backward pass.
return (new_buf, new_write_ptr), (write_ptr, actual_tokens)

def stash_fn_bwd(res, g):
write_ptr, actual_tokens = res
d_new_buf, _ = g
# Gradient w.r.t. x: read back the chunk we wrote (the gradient flows
# through the buffer slice we updated).
d_chunk = lax.dynamic_slice(d_new_buf, (write_ptr, 0), (max_chunk, hidden))
# Zero out positions beyond actual_tokens (they were masked to 0 in fwd).
mask = jnp.arange(max_chunk) < actual_tokens
d_x_chunk = jnp.where(mask[:, None], d_chunk, 0.0)
# Pad d_x back to the full worst_case size expected by the caller.
# The caller's x has shape (full_size, hidden); we only touched [:max_chunk].
# Positions [max_chunk:] had zero gradient contribution.
d_x = jnp.zeros_like(d_chunk) # will be broadcast by caller if needed
d_x = d_x.at[:max_chunk].set(d_x_chunk)
# Gradient w.r.t. buf: the updated slice is consumed, so pass d_new_buf
# back with the written region zeroed out (it has been "consumed").
zero_chunk = jnp.zeros((max_chunk, hidden), dtype=d_new_buf.dtype)
d_buf = lax.dynamic_update_slice(d_new_buf, zero_chunk, (write_ptr, 0))
return d_buf, 0, d_x, 0 # d_buf, d_write_ptr, d_x, d_actual_tokens

stash_fn.defvjp(stash_fn_fwd, stash_fn_bwd)

@jax.custom_vjp
def restore_fn(buf, read_ptr, actual_tokens, full_size):
"""Read actual_tokens rows from buf[read_ptr:] and scatter into (full_size, hidden).

The caller's sorted token buffer has shape (full_size, hidden). Only the
first actual_tokens positions are non-zero (the rest were masked in the
forward permute step). restore_fn reconstructs this layout.
"""
compact = lax.dynamic_slice(buf, (read_ptr, 0), (max_chunk, hidden))
# Scatter compact tokens into positions [0:actual_tokens], zero elsewhere.
indices = jnp.arange(full_size)
mask = indices < actual_tokens
# Clamp indices to avoid out-of-bounds; masked positions get overwritten anyway.
safe_idx = jnp.minimum(indices, max_chunk - 1)
x_full = jnp.where(mask[:, None], compact[safe_idx], 0.0)
return x_full

def restore_fn_fwd(buf, read_ptr, actual_tokens, full_size):
x_full = restore_fn(buf, read_ptr, actual_tokens, full_size)
return x_full, (read_ptr, actual_tokens)

def restore_fn_bwd(res, g_x_full):
read_ptr, actual_tokens = res
# Gradient flows back into the compact slice: gather from g_x_full.
indices = jnp.arange(max_chunk)
mask = indices < actual_tokens
d_compact = jnp.where(mask[:, None], g_x_full[:max_chunk], 0.0)
# Write d_compact into an all-zero d_buf at read_ptr.
d_buf = jnp.zeros((lax.dynamic_slice.out_aval,), dtype=g_x_full.dtype) # placeholder
# NOTE: caller must accumulate into the shared d_buf.
# Return as a zeros buffer with the slice filled; the scan will accumulate.
d_buf = lax.dynamic_update_slice(
jnp.zeros_like(g_x_full[:1].repeat(max_chunk, axis=0)), # shape hint
d_compact, (0, 0)
)
return d_buf, 0, 0, 0

restore_fn.defvjp(restore_fn_fwd, restore_fn_bwd)

return stash_fn, restore_fn


# ---------------------------------------------------------------------------
# Buffer sizing helpers
# ---------------------------------------------------------------------------

def stash_buffer_size(num_moe_layers: int, expected_per_layer: int, max_chunk: int) -> int:
"""Total rows in the shared stash buffer.

Sized for the expected cumulative token count plus one extra max_chunk
of headroom for the last layer.
"""
return num_moe_layers * expected_per_layer + max_chunk


def expected_tokens_per_layer(
batch_size: int,
num_expert_parallelism: int,
sequence_length: int,
num_experts_per_tok: int,
) -> int:
"""Expected token count per EP shard per layer with uniform routing."""
worst_case = batch_size * num_expert_parallelism * sequence_length * num_experts_per_tok
return worst_case // num_expert_parallelism
Loading