perf(ppo): gather response/loss-mask rows before log-prob+entropy CE (supersedes #2011) #2076
Open
Mantissagithub wants to merge 2 commits into
…nse gather
Two stacked memory optimizations for the log-prob/entropy cross-entropy in
slime/utils/ppo_utils.py and slime/backends/megatron_utils/loss.py.
1. Fuse log-prob + entropy into one autograd Function that reuses Megatron's
softmax buffer in place as the gradient buffer, so the only extra full-vocab
allocation is log_softmax (and only when entropy grad flows). --log-probs-chunk-size
bounds the forward working set.
2. Gather only the rows _extract_per_sample consumes (response windows) before CE,
shrinking the dominant tensor from [T, V] to [T', V] (T' = response tokens) and
scattering results back. index_select's backward is a scatter and non-response
rows have no path to the loss, so grads are bit-identical and no custom backward
is needed. This bounds T itself, fixing the THUDM#1951 regime the fused kernel alone
could not (the [T, V] tensor is ~58 GiB at T~205k, V=76032).
Flags (both default off -> behavior byte-identical to current main):
- --log-probs-response-only: gather response window (cp1 thd/bshd, zigzag CP, allgather CP)
- --log-probs-loss-mask-only: further restrict to loss_mask==1 (policy-loss path;
masked positions return 0; requires --log-probs-response-only)
Peak: O(T*V) -> O(T'*V), stacking with chunking. On 1xH200 (T=205280, V=76032, fp32,
entropy+backward, chunk 1024): full T OOMs at the 58.14 GiB allocation; response
fraction 0.25 fits in 43.6 GiB, scaling linearly.
Tests (tests/test_logprob_entropy_fused.py): forward parity (chunked/unchunked,
+/- entropy), backward parity within bf16 tolerance, TP=2 gloo gradcheck, response-only
and loss-mask parity vs the full path. tools/repro_1951.py --sweep reports the
peak-memory sweep. No Megatron internals modified.
d9cba76 to
119c9c1
Summary
This supersedes #2011. I closed #2011 deliberately — rather than land it as-is, I wanted to push the
idea to where it actually resolves the #1951 OOM, and continue the work here. #2011 reduced the
constant factor of the log-prob/entropy cross-entropy; this PR reduces its asymptotic size by
running the cross-entropy only on the tokens whose result is actually consumed.
It is a general-purpose, drop-in memory optimization (off by default, opt-in flags), it touches no
Megatron internals, and it is covered by CI parity tests plus a single-GPU reproducer — in line with
CONTRIBUTING.md (general RL optimization, verifiable correctness, clear benchmark).
cc @zhuzilin @huang3eng: thank you both for the review on #2011; the entropy-only backward test
@huang3eng asked for is included here as well.
Background / motivation
For a packed micro-batch the output layer emits logits $Z \in \mathbb{R}^{T \times V}$,
where $T$ is the total number of positions in the micro-batch (prompt + response, summed over the
packed samples) and $V$ is the per-rank vocabulary.
get_log_probs_and_entropy computes, for every row $t$, the log-probability $\ell_t$ and the entropy $H_t$.
But the consumer (_extract_per_sample) immediately throws away every row outside the response
windows $\mathcal{R}$ (with $T' := |\mathcal{R}|$ rows in total).
So $\ell_t, H_t$ for $t \notin \mathcal{R}$ are computed, allocated, and discarded. In RL post-training
prompts usually dominate, i.e. the response fraction $\rho := T'/T$ is small (long retool / reasoning
contexts), so most of the dominant tensor is wasted.
The peak of the fused cross-entropy is driven by the full-vocab buffers, which scale as
$\Theta(T\,V)$ (the retained softmax/grad buffer alone is $T\,V\cdot 4$ bytes in fp32). #2011 lowered
the constant (one buffer instead of two clones) but the scaling stayed $\Theta(T\,V)$, which is why,
as I noted there, it could not fix #1951: at $T\approx 205\text{k}$, $V=76032$ the logits tensor
itself is $205280 \times 76032 \times 4$ bytes $\approx 58.1$ GiB.
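The arithmetic above can be checked with a few lines of Python. This is only a sketch of the size of one dense fp32 buffer; the helper name buffer_gib and the response fractions in the loop are illustrative assumptions, and the numbers are not a model of the measured peak, which includes other allocations.

```python
# Size of a single dense [rows, vocab] buffer, in GiB.
GIB = 2**30
BYTES_FP32 = 4


def buffer_gib(rows: int, vocab: int, bytes_per_elem: int = BYTES_FP32) -> float:
    """Return the size in GiB of a dense [rows, vocab] tensor."""
    return rows * vocab * bytes_per_elem / GIB


T, V = 205_280, 76_032          # sizes quoted in this PR
full = buffer_gib(T, V)
print(f"full [T, V] buffer: {full:.2f} GiB")   # 58.14 GiB

# Hypothetical response fractions, chosen only to show the linear scaling.
for rho in (0.5, 0.25, 0.125):
    rows = int(rho * T)
    print(f"rho={rho}: [{rows}, V] buffer = {buffer_gib(rows, V):.2f} GiB")
```

Each halving of the response fraction halves the buffer, which is the linear dependence on $\rho$ that the rest of this description relies on.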
What this PR does
Gather the response rows before the cross-entropy, run CE on the compact tensor
$Z_{\mathcal{R}} \in \mathbb{R}^{T' \times V}$, and scatter the results back to full length,
so the peak drops from $\Theta(T\,V)$ to $\Theta(T'\,V) = \Theta(\rho\,T\,V)$. This stacks
multiplicatively with the existing --log-probs-chunk-size $c$: the forward working set is
$O(\min(c, T')\,V)$ and the retained buffer is $O(T'\,V)$.
Why no custom backward is needed
The gather is index_select, whose vector–Jacobian product is a scatter. Writing
$\tilde Z := Z_{\mathcal{R}}$ for the gathered tensor,
$$\frac{\partial L}{\partial Z_t} = \begin{cases} \dfrac{\partial L}{\partial \tilde Z_{\pi(t)}} & t \in \mathcal{R}, \\ 0 & t \notin \mathcal{R}, \end{cases}$$
where $\pi$ maps a full-length index to its position in $\mathcal{R}$. This is exactly correct:
rows $t \notin \mathcal{R}$ have no path to the loss $L$ (their $\ell_t, H_t$ are discarded), so
$\partial L/\partial Z_t = 0$ independently of the optimization. Autograd's stock index_select
backward produces precisely this, so the fused Function from #2011 is reused unchanged and results are
bit-identical to the full path.
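The argument can be illustrated on a toy tensor. The sketch below is not the code in loss.py: it uses plain torch.log_softmax as a stand-in for the fused Function, the keep index and targets are made up, and filling the scattered non-response rows with 0 is an assumption. It compares the gradient of the full path (CE on all rows, then discard) against the gathered path (index_select, CE, scatter back).

```python
import torch

torch.manual_seed(0)
T, V = 12, 7                                # toy sizes
keep = torch.tensor([3, 4, 5, 9, 10])       # hypothetical response rows R
targets = torch.randint(V, (T,))


def logprob_entropy(logits, tgt):
    """Per-row target log-prob and softmax entropy (plain PyTorch stand-in)."""
    logp = torch.log_softmax(logits, dim=-1)
    lp = logp.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)
    ent = -(logp.exp() * logp).sum(-1)
    return lp, ent


def loss_from(lp, ent):
    # Arbitrary scalar that depends on both outputs.
    return (0.7 * lp - 0.3 * ent).sum()


# Full path: CE on all T rows, then keep only the response rows.
z_full = torch.randn(T, V, dtype=torch.float64, requires_grad=True)
lp_f, ent_f = logprob_entropy(z_full, targets)
loss_from(lp_f[keep], ent_f[keep]).backward()

# Gathered path: index_select first, CE on T' rows, scatter back to length T.
z_gath = z_full.detach().clone().requires_grad_(True)
lp_c, ent_c = logprob_entropy(z_gath.index_select(0, keep),
                              targets.index_select(0, keep))
lp_g = lp_c.new_zeros(T).index_copy(0, keep, lp_c)     # assumed fill value: 0
ent_g = ent_c.new_zeros(T).index_copy(0, keep, ent_c)
loss_from(lp_g[keep], ent_g[keep]).backward()

dropped = torch.ones(T, dtype=torch.bool)
dropped[keep] = False
print(torch.allclose(z_full.grad, z_gath.grad))        # gradients agree
print(bool((z_gath.grad[dropped] == 0).all()))         # dropped rows get zero grad
```

No custom backward appears anywhere in the gathered path; the zero rows in the gradient come from the stock index_select backward.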
Optional loss-mask restriction
With --log-probs-loss-mask-only the kept set is further reduced to the tokens that contribute to the
loss,
$$\mathcal{M} = \{\, t \in \mathcal{R} : \mathrm{loss\_mask}_t = 1 \,\},$$
giving peak $\Theta(|\mathcal{M}|\,V)$. Positions in $\mathcal{R}\setminus\mathcal{M}$ are returned as
$0$, so this is only valid on the policy-loss path where they are masked out downstream; hence it is a
separate opt-in flag and validated to require --log-probs-response-only.
Flags (both default off → behavior byte-identical to current main)
- --log-probs-response-only: gather response rows before CE. General; valid for both log-probs and
entropy outputs; all CP layouts (cp1 thd/bshd, zigzag CP, allgather CP).
- --log-probs-loss-mask-only: further restrict to loss_mask == 1 (policy-loss path).
The keep-index is built in lock-step with _extract_per_sample (same per-CP-mode offset math, single
source of truth), so the gathered set is exactly the consumed set.
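As a rough picture of what a keep index looks like, here is a minimal sketch for the simplest case only. It assumes a single-rank packed layout in which each sample occupies prompt_len + response_len contiguous rows and the response rows are the last response_len of them; it ignores the next-token shift and every CP layout. The function build_keep_index is hypothetical and is not the _response_keep_index added in this PR.

```python
from typing import List, Optional, Sequence, Tuple


def build_keep_index(
    samples: Sequence[Tuple[int, int]],
    loss_masks: Optional[Sequence[Sequence[int]]] = None,
) -> List[int]:
    """Return the packed row indices to keep.

    samples: (prompt_len, response_len) for each packed sample.
    loss_masks: optional 0/1 mask over each sample's response tokens;
        when given, only positions with mask == 1 are kept.
    """
    keep: List[int] = []
    offset = 0
    for i, (prompt_len, response_len) in enumerate(samples):
        start = offset + prompt_len          # first response row of sample i
        for j in range(response_len):
            if loss_masks is None or loss_masks[i][j] == 1:
                keep.append(start + j)
        offset += prompt_len + response_len  # move to the next packed sample
    return keep


samples = [(5, 3), (2, 4)]                   # T = 14 rows in total
masks = [[1, 0, 1], [1, 1, 0, 0]]

response_rows = build_keep_index(samples)    # [5, 6, 7, 10, 11, 12, 13]
loss_rows = build_keep_index(samples, masks) # [5, 7, 10, 11]
print(response_rows, loss_rows)
```

The second call returns a subset of the first, which mirrors $\mathcal{M} \subseteq \mathcal{R}$ and is why the loss-mask flag requires the response-only flag.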
Benchmark (1×H200 141 GB, real measurement)
Reproducer: tools/repro_1951.py --sweep. The setting matches #2011's regime:
$T=205280,\ V=76032$, fp32, --with-entropy --backward --log-probs-chunk-size 1024.
$\rho$ is the response fraction (emulating the gather by running CE on $\rho T$ rows).
The failing allocation is exactly the 58.14 GiB from #1951/#2011, and bounding $T \to T'$ turns the
OOM step into a fitting one, with peak scaling linearly in $\rho$; i.e. this is the part #2011 said
"bounding $T$ is what does," done directly.
Correctness / tests
tests/test_logprob_entropy_fused.py (extended):
- […] (atol=rtol=1e-6), with and without entropy, chunked and unchunked;
- […] are zero-grad in both);
- […] loss_mask == 1 positions and returns […]
pytest tests/test_logprob_entropy_fused.py tests/test_chunked_gae.py passes; pre-commit (black)
clean. No Megatron internals are modified.
Changes
- slime/backends/megatron_utils/loss.py: _response_keep_index (all CP layouts) + gather/scatter
in get_log_probs_and_entropy; optional full_loss_mask kwarg.
- slime/utils/arguments.py: --log-probs-response-only, --log-probs-loss-mask-only (+ validation).
- slime/backends/megatron_utils/model.py: thread full_loss_mask through the forward_only partial.
- tests/test_logprob_entropy_fused.py: parity tests above.
- tools/repro_1951.py: --response-frac / --sweep peak-memory sweep.
- examples/retool/retool_qwen3_4b_rl.sh: enable --log-probs-response-only (long-prompt traces).
Relates to #1951. Supersedes #2011.