[DRAFT] Paged Stashing for ring-of-experts MoE Activation Memory #3493
Draft
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from
Problem
-------
With ring_of_experts=True and EP=N, the GMM activation buffer is inflated to:

    worst_case = batch * EP * seq * top_k

e.g. 262,144 tokens for bs=2, EP=4, seq=4096, top_k=8. Checkpointing moe_mlpwo at this size costs ~210 GB of host memory for a 60-layer model, exceeding typical 180 GB CPU RAM limits.

Approach (inspired by Megatron-LM PR #2690, "Paged Stashing")
-------------------------------------------------------------
Decouple the compute buffer size from the storage buffer size:

Forward:
- run GMM on the full worst_case buffer (no token dropping)
- stash only the actual tokens (~65k) into a shared static buffer at a dynamically tracked offset

Backward:
- restore the compact tokens back into the worst_case-shaped buffer
- run GMM backward correctly on all tokens

The shared buffer is sized for the *expected cumulative* token count across all layers (~60 * 65k tokens => ~52 GB) rather than the per-layer worst case (210 GB). Transient per-layer imbalance is absorbed by the shared budget.

Implementation
--------------
- layers/paged_stash.py: core stash/restore primitives using jax.custom_vjp and lax.dynamic_update_slice (static shape, dynamic start index).
- layers/moe.py: the ring_of_experts path uses stash_fn/restore_fn when ring_paged_stash=True, falling back to checkpoint_name otherwise.
- configs/types.py: adds ring_paged_stash (bool) and ring_paged_stash_safety_margin (float, default 1.5) config fields.

Memory comparison (bs=2, EP=4, seq=4096, top_k=8, hidden=7168, 60 layers):
- Baseline (no cap):  ~210 GB host (moe_mlpwo alone)
- Static 50% cap:     ~105 GB host
- Paged stash (1.5x):  ~78 GB host (safety_margin=1.5 => max_chunk=98k)
- Paged stash (1.0x):  ~52 GB host (safety_margin=1.0 => max_chunk=65k)

Status / TODOs
--------------
This is a draft for upstream feedback. The following items are incomplete:

1. Decoder scan carry wiring: stash_buf, write_ptr, and layer_sizes must be threaded through the decoder __call__ signatures in decoders.py and deepseek.py. Currently the moe.py code references stash_buf/write_ptr as free variables to show the intended interface.
2. restore_fn backward: the d_buf accumulation across layers needs careful handling in the scan backward -- the current implementation is a sketch.
3. moe_mlpwi_0 / moe_mlpwi_1: the same technique can be applied to the wi GMM outputs, saving an additional ~60 GB each if they are offloaded.
4. Tests: unit tests for stash_fn/restore_fn round-trip correctness and a gradient check via jax.test_util.check_grads.
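The sizes quoted in the commit message can be checked with a back-of-envelope calculation, assuming bf16 activations (2 bytes per element) and hidden=7168; the helper names below are illustrative, not the PR's API:

```python
# Back-of-envelope check of the memory numbers above (bf16 = 2 bytes/elem).

batch, ep, seq, top_k = 2, 4, 4096, 8
hidden, layers, bytes_per_elem = 7168, 60, 2

worst_case = batch * ep * seq * top_k     # tokens/device after ring all-gather
expected = worst_case // ep               # actual (non-padded) tokens per layer

def gib(n_bytes):
    return n_bytes / 2**30

# Baseline: checkpoint moe_mlpwo at worst-case shape for every layer.
baseline = layers * worst_case * hidden * bytes_per_elem

# Paged stash: shared buffer sized for the cumulative expected load,
# scaled by the safety margin.
def paged(margin):
    return margin * layers * expected * hidden * bytes_per_elem

print(worst_case, expected)                         # 262144 65536
print(int(gib(baseline)))                           # 210
print(int(gib(paged(1.5))), int(gib(paged(1.0))))   # 78 52
```

These reproduce the ~210 GB baseline and the ~78 GB / ~52 GB paged-stash figures from the table above.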
Problem
When `use_ring_of_experts=True` with expert parallelism EP=4, the ring all-gather inflates the token buffer on each device to `batch × EP × seq × top_k` tokens. For DeepSeek V3 671B with `batch=2, seq=4096, top_k=8, EP=4`, this is 262,144 tokens/device. The `moe_mlpwo` GMM output has shape `(262144, 7168)` = 3.5 GB per layer. With `scan_layers=True` over 60 MoE layers and host offloading, this totals ~210 GB of host memory, exceeding typical limits (180 GB).

Solution: Paged Stashing
Inspired by Megatron-LM PR #2690, this PR introduces a paged stash mechanism that decouples the compute buffer (static worst-case shape, required by XLA) from the storage buffer (actual tokens only).
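The decoupling hinges on two XLA-friendly primitives: a statically shaped write into the shared buffer at a traced offset, and the matching read. A minimal toy sketch (buffer sizes here are invented, not the PR's real configuration):

```python
import jax.numpy as jnp
from jax import lax

# Toy illustration of the stash write/read: the chunk shape is static, as
# XLA requires, and only the start offset is a traced dynamic value.
MAX_CHUNK, HIDDEN = 8, 4

stash_buf = jnp.zeros((32, HIDDEN))     # shared storage buffer
offset = jnp.int32(0)                   # dynamically tracked write offset

# forward: stash one layer's compact tokens at the current offset
chunk = jnp.ones((MAX_CHUNK, HIDDEN))
stash_buf = lax.dynamic_update_slice(stash_buf, chunk, (offset, 0))

# backward: read the same statically-shaped region back out
restored = lax.dynamic_slice(stash_buf, (offset, 0), (MAX_CHUNK, HIDDEN))
```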
Key idea

- Size the shared buffer for the cumulative expected load, not the per-layer worst case: `total_capacity = num_moe_layers × expected_tokens_per_layer + max_chunk_slack`
- Writes use `lax.dynamic_update_slice` (static size, dynamic start index), which is XLA-compatible
- Expected tokens per layer = `worst_case / EP = 262144 / 4 = 65536`, so the shared buffer is ~4× smaller than per-layer worst-case storage

Memory comparison (DeepSeek V3 671B, EP=4, 60 MoE layers)

| Strategy | Host memory for `moe_mlpwo` |
| --- | --- |
| Baseline (no cap) | ~210 GB |
| Static 50% cap | ~105 GB |
| Paged stash (safety margin 1.5) | ~78 GB |
| Paged stash (safety margin 1.0) | ~52 GB |

Implementation
New file: `src/MaxText/layers/paged_stash.py`

- `make_stash_fns(max_chunk, hidden)`: returns a `(stash_fn, restore_fn)` pair with `jax.custom_vjp` so gradients flow correctly through the compact/expand operations
- `stash_buffer_size(num_moe_layers, expected_per_layer, max_chunk)`: buffer sizing helper
- `expected_tokens_per_layer(batch, ep, seq, top_k)`: computes the expected load (= worst_case / EP)

Modified:
`src/MaxText/layers/moe.py`

When `config.ring_paged_stash=True`, replaces the `checkpoint_name("moe_mlpwo")` call with `stash_fn`/`restore_fn` calls that pack the actual tokens into the shared buffer.

Modified:
`src/maxtext/configs/types.py`

Adds two new config fields:

- `ring_paged_stash: bool = False`
- `ring_paged_stash_safety_margin: float = 1.5`

TODOs / Open Questions

- `stash_buf` and `write_ptr` need to be threaded through the `lax.scan` carry in `decoders.py`/`deepseek.py`. Currently the per-layer integration is sketched, but the cross-layer buffer threading is not yet wired up. Looking for guidance on the best pattern here.
- `restore_fn` backward: the gradient accumulation in `restore_fn_bwd` needs review, specifically the scatter from compact to full shape.
- `moe_mlpwi_0`/`moe_mlpwi_1`: the same technique could apply to the wi outputs if they are offloaded.
stash_bufandwrite_ptrneed to be threaded through thelax.scancarry indecoders.py/deepseek.py. Currently the per-layer integration is sketched but the cross-layer buffer threading is not yet wired up. Looking for guidance on the best pattern here.restore_fnbackward: The gradient accumulation inrestore_fn_bwdneeds review — specifically the scatter from compact → full shape.moe_mlpwi_0/moe_mlpwi_1: Same technique could apply to wi outputs if they are offloaded.