Skip to content

fix(dp_schedule): drop trailing rollouts when the aligned micro-batch target exceeds the sample count#2065

Open
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/dp-schedule-trailing-groups
Open

fix(dp_schedule): drop trailing rollouts when the aligned micro-batch target exceeds the sample count#2065
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/dp-schedule-trailing-groups

Conversation

@EazyReal

@EazyReal EazyReal commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Problem

build_dp_schedule in slime/utils/dp_schedule.py packs each step's samples into micro-batches (step_mbs) and then rounds the micro-batch count up to align_to = dp_size * (mb_group if vpp > 1 else 1) so every DP rank gets the same num_microbatches. In the long-trajectory dynamic-batch regime (use_dynamic_batch_size), first-fit packing puts roughly one sample per micro-batch, so the micro-batch count equals the sample count. When the up-rounded target exceeds the step's sample count, expand_bins_by_splitting cannot reach it — its docstring is explicit that it "stops early if every remaining bin is a singleton (no bin can be split further)" — so len(step_mbs) stayed below target_K and the function died on a bare AssertionError.

Concrete trigger: at DP4 a step of 514 single-sample micro-batches needs 516 to align (514 -> ceil(514/4)*4 = 516), which is unreachable because all 514 bins are singletons.

Before vs After

Same input — DP4, dynamic batch, max_tokens_per_gpu=32768, 512 rollouts where 2 fan out to 2 samples each (a GRPO-style group), giving 514 single-sample micro-batches:

Before — align-up target 516 > 514, singletons cannot split, bare assert:

build_dp_schedule(args, tp, total_lengths,
                  global_batch_size=512, rollout_indices=rollout_indices)
# AssertionError: dynamic path: could only produce 514 mbs after maximal
#   splitting; need 516. step 0 has 514 samples, below the alignment threshold (4).

After — drops the 2 trailing single-sample rollouts so the kept prefix (510 rollouts / 512 samples) tiles by 4, logs the drop, and records the kept count:

partitions, mbi, num_microbatches, global_batch_sizes = build_dp_schedule(...)
# WARNING [dp_schedule] step 0: dropped 2 trailing rollout(s) (510 kept, 512 samples)
#   so the aligned micro-batch target stays reachable (dp_size=4, align_to=4).
assert global_batch_sizes == [510]          # was constant 512; now the kept count
assert sum(len(m) for m in mbi) % 4 == 0    # 512 mbs, evenly divisible across DP4
# dropped rollouts stay atomic: kept rollouts are 0..509 with no holes, no group split

The drop floor is sample-count based, not rollout-count based: reachability is bounded by the sample count (the splitter can grow to at most one bin per sample), so the loop keeps dropping while the target is unreachable as long as a full aligned block (>= align_to samples) remains. With ragged rollout sizes [3,2,3,1,2] (11 long samples) at DP4, the old len > dp_size floor would have kept 4 rollouts / 9 samples — still unaligned, still a crash; the sample-count floor keeps dropping to 3 rollouts / 8 samples (global_batch_sizes == [3]), which aligns.

For the genuine residual where rollout atomicity means no kept prefix has a sample count divisible by align_to (e.g. sizes [1,1,4,1,3,1], prefix sample counts 1,2,6,7,10,11 — none divisible by 4), we now fail loud with an actionable ValueError instead of a bare assert:

ValueError: dp_schedule step 0: cannot align micro-batches to a multiple of
  align_to=4 (dp_size=4). After dropping trailing rollouts the step has 11 samples
  packed into 11 singleton micro-batches, but the aligned target is 12 and singleton
  bins cannot split further. This happens with ragged rollout sizes where every long
  sample fills its own micro-batch. Adjust global_batch_size / n_samples_per_prompt
  (or max_tokens_per_gpu) so each step's kept-sample count can reach a multiple of align_to.

Fix

In build_dp_schedule, before aligning, drop the minimum number of whole trailing rollouts until the up-rounded target is reachable, then re-pack. Rollouts stay atomic (a GRPO group / multi-sample rollout is never split across the drop), and the kept rollout count is recorded in global_batch_sizes[s] so per-group loss denominators track what actually trains. The unreachable residual raises the actionable ValueError above. The micro-batch count helper is factored into _aligned_target and sample collection into _collect_step_samples / _pack for reuse across the drop loop.

Why this is the right fix

  • Default-path safe / no-op when not triggered. The drop loop is gated on use_dynamic_batch_size and align_to > 1 and only runs when the aligned target already exceeds the step's sample count. The common one-sample-per-rollout case and the entire static-batch path are untouched, and global_batch_sizes stays equal to global_batch_size.
  • Correct denominator. Recording the kept rollout count (not a constant global_batch_size) keeps per-group loss denominators consistent with the samples that actually train after a drop.
  • Atomicity preserved. Whole trailing rollouts are dropped, never a partial GRPO group; kept rollouts remain a contiguous prefix 0..k-1 with no holes.
  • Loud over silent. The unreachable-alignment case raises a ValueError naming the step, kept sample count, and align_to, and points at global_batch_size / n_samples_per_prompt / max_tokens_per_gpu — replacing the bare assert.
  • CI-verifiable, no new abstraction. tests/test_dp_schedule.py is CPU-only and torch-free, so it runs under the existing unit-test job. Three new @pytest.mark.unit cases cover the ragged fan-out drop (514 -> kept prefix that tiles), dropping past the naive dp_size floor down to a tiling prefix (8 samples), and the unreachable-alignment ValueError. The existing suite (including the FLOPs-balancing tests) continues to pass.

@EazyReal EazyReal force-pushed the upstream-pr/dp-schedule-trailing-groups branch from 4aa2ddd to edd17fe Compare June 12, 2026 06:03
…exceeds the sample count

In the long-trajectory dynamic-batch regime, packing puts roughly one sample
per micro-batch, so rounding the per-step micro-batch count up to a multiple of
align_to (= dp_size, or dp_size * mb_group under VPP) can exceed the step's
sample count. `expand_bins_by_splitting` then cannot reach the target because
singleton bins cannot split, and `build_dp_schedule` died on a bare
AssertionError. Concretely at DP4 a step of 514 single-sample bins needs 516
micro-batches, which is unreachable.

Fix: before aligning, drop the minimum number of whole TRAILING rollouts until
the up-rounded target is reachable, then re-pack and record the kept rollout
count in `global_batch_sizes` so the per-group loss denominators track what
actually trains. Rollouts stay atomic (GRPO groups are never split across the
drop), and the floor is sample-count based: keep dropping while the target is
unreachable as long as a full aligned block (>= align_to samples) remains. For
the genuine residual where rollout atomicity means no kept prefix has a
sample count divisible by align_to, fail loud with an actionable ValueError
naming the step, sample count, and align_to instead of a bare assert.

CI-verifiable: tests/test_dp_schedule.py is CPU-only / torch-free and runs under
the existing unit-test job. Three new @pytest.mark.unit cases cover the
ragged-fan-out drop, dropping past the old dp_size floor down to a tiling
prefix, and the unreachable-alignment ValueError; the existing suite continues
to pass.

Default path safety: the drop loop is gated on `use_dynamic_batch_size and
align_to > 1` and only triggers when the aligned target already exceeds the
sample count, so the common one-sample-per-rollout / static-batch paths are
untouched and `global_batch_sizes` stays equal to `global_batch_size`.
@EazyReal EazyReal force-pushed the upstream-pr/dp-schedule-trailing-groups branch from edd17fe to 4ce9792 Compare June 12, 2026 06:21
@EazyReal EazyReal marked this pull request as draft June 12, 2026 06:28
@EazyReal EazyReal marked this pull request as ready for review June 12, 2026 06:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant