Skip to content

perf(ppo): gather response/loss-mask rows before log-prob+entropy CE (supersedes #2011)#2076

Open
Mantissagithub wants to merge 2 commits into
THUDM:mainfrom
Mantissagithub:perf/logprob-response-only-gather
Open

perf(ppo): gather response/loss-mask rows before log-prob+entropy CE (supersedes #2011)#2076
Mantissagithub wants to merge 2 commits into
THUDM:mainfrom
Mantissagithub:perf/logprob-response-only-gather

Conversation

@Mantissagithub

Copy link
Copy Markdown

Summary

This supersedes #2011. I closed #2011 deliberately — rather than land it as-is, I wanted to push the
idea to where it actually resolves the #1951 OOM, and continue the work here. #2011 reduced the
constant factor of the log-prob/entropy cross-entropy; this PR reduces its asymptotic size by
running the cross-entropy only on the tokens whose result is actually consumed.

It is a general-purpose, drop-in memory optimization (off by default, opt-in flags), it touches no
Megatron internals, and it is covered by CI parity tests plus a single-GPU reproducer — in line with
CONTRIBUTING.md (general RL optimization, verifiable correctness, clear benchmark).

cc @zhuzilin @huang3eng — thank you both for the review on #2011; the entropy-only backward test
@huang3eng asked for is included here as well.

Background / motivation

For a packed micro-batch the output layer emits logits

$$Z \in \mathbb{R}^{T \times V} \quad (\text{fp32}),$$

where $T$ is the total number of positions in the micro-batch (prompt + response, summed over the
packed samples) and $V$ is the per-rank vocabulary. get_log_probs_and_entropy computes, for every
row $t$,

$$\ell_t = z_{t, y_t} - \log!\sum_{v} e^{z_{t,v}}, \qquad H_t = -\sum_{v} p_{t,v},\log p_{t,v}, \quad p_{t,\cdot} = \mathrm{softmax}(z_t).$$

But the consumer (_extract_per_sample) immediately throws away every row outside the response
windows

$$\mathcal{R} = \bigcup_i [,s_i,,e_i,), \qquad T' := |\mathcal{R}|.$$

So $\ell_t, H_t$ for $t \notin \mathcal{R}$ are computed, allocated, and discarded. In RL post-training
prompts usually dominate, i.e. the response fraction $\rho := T'/T$ is small (long retool / reasoning
contexts), so most of the dominant tensor is wasted.

The peak of the fused cross-entropy is driven by the full-vocab buffers, which scale as
$\Theta(T,V)$ (the retained softmax/grad buffer alone is $T,V\cdot 4$ bytes in fp32). #2011 lowered
the constant (one buffer instead of two clones) but the scaling stayed $\Theta(T,V)$ — which is why,
as I noted there, it could not fix #1951: at $T\approx 205\text{k}$, $V=76032$ the logits tensor
itself is $205280 \times 76032 \times 4 \approx 58.1$ GiB.

What this PR does

Gather the response rows before the cross-entropy, run CE on the compact tensor, and scatter the
results back to full length:

$$Z_{\mathcal{R}} = Z[\mathcal{R}] \in \mathbb{R}^{T' \times V}, \qquad (\ell, H)_{\mathcal{R}} = \mathrm{CE}(Z_{\mathcal{R}}),$$

so the peak drops from $\Theta(T,V)$ to $\Theta(T',V) = \Theta(\rho,T,V)$. This stacks
multiplicatively with the existing --log-probs-chunk-size $c$: the forward working set is
$O(\min(c, T'),V)$ and the retained buffer is $O(T',V)$.

Why no custom backward is needed

The gather is index_select, whose vector–Jacobian product is a scatter:

$$\Big(\frac{\partial L}{\partial Z}\Big)_t = \begin{cases} \big(\partial L/\partial Z_{\mathcal{R}}\big)_{\pi(t)}, & t \in \mathcal{R},\[4pt] 0, & t \notin \mathcal{R}, \end{cases}$$

where $\pi$ maps a full-length index to its position in $\mathcal{R}$. This is exactly correct:
rows $t \notin \mathcal{R}$ have no path to the loss $L$ (their $\ell_t, H_t$ are discarded), so
$\partial L/\partial Z_t = 0$ independently of the optimization. Autograd's stock index_select
backward produces precisely this, so the fused Function from #2011 is reused unchanged and results are
bit-identical to the full path.

Optional loss-mask restriction

With --log-probs-loss-mask-only the kept set is further reduced to the tokens that contribute to the
loss,

$$\mathcal{M} = {, t \in \mathcal{R} : m_t = 1 ,} \subseteq \mathcal{R},$$

giving peak $\Theta(|\mathcal{M}|,V)$. Positions in $\mathcal{R}\setminus\mathcal{M}$ are returned as
$0$, so this is only valid on the policy-loss path where they are masked out downstream; hence it is a
separate opt-in flag and validated to require --log-probs-response-only.

Flags (both default off → behavior byte-identical to current main)

  • --log-probs-response-only — gather response rows before CE. General; valid for both log-probs and
    entropy outputs; all CP layouts (cp1 thd/bshd, zigzag CP, allgather CP).
  • --log-probs-loss-mask-only — further restrict to loss_mask == 1 (policy-loss path).

The keep-index is built in lock-step with _extract_per_sample (same per-CP-mode offset math, single
source of truth), so the gathered set is exactly the consumed set.

Benchmark (1×H200 141 GB, real measurement)

Reproducer: tools/repro_1951.py --sweep. Setting matches #2011's regime —
$T=205280,\ V=76032$, fp32, --with-entropy --backward --log-probs-chunk-size 1024.
$\rho$ is the response fraction (emulating the gather by running CE on $\rho T$ rows):

$\rho = T'/T$ $T'$ peak result
1.0 (no gather, = today) 205,280 tries to alloc 58.14 GiB OOM
0.5 102,640 87.2 GiB fits
0.25 51,320 43.6 GiB fits
0.0625 12,830 10.9 GiB fits

The failing allocation is exactly the 58.14 GiB from #1951/#2011, and bounding $T \to T'$ turns the
OOM step into a fitting one, with peak scaling linearly in $\rho$ — i.e. this is the part #2011 said
"bounding $T$ is what does," done directly.

Correctness / tests

tests/test_logprob_entropy_fused.py (extended):

  • forward parity: response-only path returns log-probs/entropy identical to the full path
    (atol=rtol=1e-6), with and without entropy, chunked and unchunked;
  • backward parity: grads on the logits match the full path within bf16 tolerance (non-response rows
    are zero-grad in both);
  • loss-mask path: equals the full path on loss_mask == 1 positions and returns $0$ on masked ones.

pytest tests/test_logprob_entropy_fused.py tests/test_chunked_gae.py passes; pre-commit (black)
clean. No Megatron internals are modified.

Changes

  • slime/backends/megatron_utils/loss.py_response_keep_index (all CP layouts) + gather/scatter
    in get_log_probs_and_entropy; optional full_loss_mask kwarg.
  • slime/utils/arguments.py--log-probs-response-only, --log-probs-loss-mask-only (+ validation).
  • slime/backends/megatron_utils/model.py — thread full_loss_mask through the forward_only partial.
  • tests/test_logprob_entropy_fused.py — parity tests above.
  • tools/repro_1951.py--response-frac / --sweep peak-memory sweep.
  • examples/retool/retool_qwen3_4b_rl.sh — enable --log-probs-response-only (long-prompt traces).

Relates to #1951. Supersedes #2011.

…nse gather

  Two stacked memory optimizations for the log-prob/entropy cross-entropy in
  slime/utils/ppo_utils.py and slime/backends/megatron_utils/loss.py.

  1. Fuse log-prob + entropy into one autograd Function that reuses Megatron's
     softmax buffer in place as the gradient buffer, so the only extra full-vocab
     allocation is log_softmax (and only when entropy grad flows). --log-probs-chunk-size
     bounds the forward working set.

  2. Gather only the rows _extract_per_sample consumes (response windows) before CE,
     shrinking the dominant tensor from [T, V] to [T', V] (T' = response tokens) and
     scattering results back. index_select's backward is a scatter and non-response
     rows have no path to the loss, so grads are bit-identical and no custom backward
     is needed. This bounds T itself, fixing the THUDM#1951 regime the fused kernel alone
     could not (the [T, V] tensor is ~58 GiB at T~205k, V=76032).

  Flags (both default off -> behavior byte-identical to current main):
  - --log-probs-response-only: gather response window (cp1 thd/bshd, zigzag CP, allgather CP)
  - --log-probs-loss-mask-only: further restrict to loss_mask==1 (policy-loss path;
    masked positions return 0; requires --log-probs-response-only)

  Peak: O(T*V) -> O(T'*V), stacking with chunking. On 1xH200 (T=205280, V=76032, fp32,
  entropy+backward, chunk 1024): full T OOMs at the 58.14 GiB allocation; response
  fraction 0.25 fits in 43.6 GiB, scaling linearly.

  Tests (tests/test_logprob_entropy_fused.py): forward parity (chunked/unchunked,
  +/- entropy), backward parity within bf16 tolerance, TP=2 gloo gradcheck, response-only
  and loss-mask parity vs the full path. tools/repro_1951.py --sweep reports the
  peak-memory sweep. No Megatron internals modified.
@Mantissagithub Mantissagithub force-pushed the perf/logprob-response-only-gather branch from d9cba76 to 119c9c1 Compare June 14, 2026 14:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Question] retool example: compute_log_probs(logits.clone(), tokens, tp_group) torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 58.15 GiB.

1 participant