fix(dp_schedule): drop trailing rollouts when the aligned micro-batch target exceeds the sample count #2065
Open
EazyReal wants to merge 1 commit into
…exceeds the sample count

In the long-trajectory dynamic-batch regime, packing puts roughly one sample per micro-batch, so rounding the per-step micro-batch count up to a multiple of align_to (= dp_size, or dp_size * mb_group under VPP) can exceed the step's sample count. `expand_bins_by_splitting` then cannot reach the target because singleton bins cannot split, and `build_dp_schedule` died on a bare AssertionError. Concretely, at DP4 a step of 514 single-sample bins needs 516 micro-batches, which is unreachable.

Fix: before aligning, drop the minimum number of whole TRAILING rollouts until the up-rounded target is reachable, then re-pack and record the kept rollout count in `global_batch_sizes` so the per-group loss denominators track what actually trains. Rollouts stay atomic (GRPO groups are never split across the drop), and the floor is sample-count based: keep dropping while the target is unreachable as long as a full aligned block (>= align_to samples) remains. For the genuine residual where rollout atomicity means no kept prefix has a sample count divisible by align_to, fail loudly with an actionable ValueError naming the step, sample count, and align_to instead of a bare assert.

CI-verifiable: tests/test_dp_schedule.py is CPU-only / torch-free and runs under the existing unit-test job. Three new @pytest.mark.unit cases cover the ragged-fan-out drop, dropping past the old dp_size floor down to a tiling prefix, and the unreachable-alignment ValueError; the existing suite continues to pass.

Default path safety: the drop loop is gated on `use_dynamic_batch_size and align_to > 1` and only triggers when the aligned target already exceeds the sample count, so the common one-sample-per-rollout / static-batch paths are untouched and `global_batch_sizes` stays equal to `global_batch_size`.
Problem

`build_dp_schedule` in `slime/utils/dp_schedule.py` packs each step's samples into micro-batches (`step_mbs`) and then rounds the micro-batch count up to `align_to = dp_size * (mb_group if vpp > 1 else 1)` so every DP rank gets the same `num_microbatches`. In the long-trajectory dynamic-batch regime (`use_dynamic_batch_size`), first-fit packing puts roughly one sample per micro-batch, so the micro-batch count roughly equals the sample count. When the up-rounded target exceeds the step's sample count, `expand_bins_by_splitting` cannot reach it: its docstring is explicit that it "stops early if every remaining bin is a singleton (no bin can be split further)". As a result, `len(step_mbs)` stayed below `target_K` and the function died on a bare `AssertionError`.

Concrete trigger: at DP4 a step of 514 single-sample micro-batches needs 516 to align (`514 -> ceil(514/4)*4 = 516`), which is unreachable because all 514 bins are singletons.

Before vs After
Same input: DP4, dynamic batch, `max_tokens_per_gpu=32768`, 512 rollouts where 2 fan out to 2 samples each (a GRPO-style group), giving 514 single-sample micro-batches.

Before: the align-up target is `516 > 514`, singletons cannot split, and the function fails on a bare assert.

After: the 2 trailing single-sample rollouts are dropped so the kept prefix (510 rollouts / 512 samples) tiles by 4, the drop is logged, and the kept count is recorded.
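The arithmetic in this comparison can be checked with a minimal sketch. The helper names `round_up_to_multiple` and `is_reachable` are hypothetical and are not the PR's `_aligned_target`; the sketch also assumes VPP is off, so `align_to` reduces to `dp_size`.

```python
def round_up_to_multiple(count: int, align_to: int) -> int:
    """Smallest multiple of align_to that is >= count (integer ceil)."""
    return ((count + align_to - 1) // align_to) * align_to


def is_reachable(num_samples: int, num_microbatches: int, align_to: int) -> bool:
    """Splitting can produce at most one micro-batch per sample, so the
    aligned target is reachable only if it does not exceed the sample count."""
    return round_up_to_multiple(num_microbatches, align_to) <= num_samples


# Assumed values for illustration: DP4 with VPP off, so mb_group is unused.
dp_size, vpp, mb_group = 4, 1, 1
align_to = dp_size * (mb_group if vpp > 1 else 1)

target_before = round_up_to_multiple(514, align_to)  # 514 singleton bins
target_after = round_up_to_multiple(512, align_to)   # after dropping 2 samples

before_ok = is_reachable(514, 514, align_to)
after_ok = is_reachable(512, 512, align_to)
```

With 514 singleton micro-batches the target rounds up past the sample count, while the 512-sample prefix is already a multiple of 4 and needs no splitting.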
The drop floor is sample-count based, not rollout-count based: reachability is bounded by the sample count (the splitter can grow to at most one bin per sample), so the loop keeps dropping while the target is unreachable as long as a full aligned block (`>= align_to` samples) remains. With ragged rollout sizes `[3,2,3,1,2]` (11 long samples) at DP4, the old `len > dp_size` floor would have kept 4 rollouts / 9 samples, which is still unaligned and still a crash; the sample-count floor keeps dropping to 3 rollouts / 8 samples (`global_batch_sizes == [3]`), which aligns.

For the genuine residual where rollout atomicity means no kept prefix has a sample count divisible by `align_to` (e.g. sizes `[1,1,4,1,3,1]`, prefix sample counts `1,2,6,7,10,11`, none divisible by 4), we now fail loudly with an actionable `ValueError` instead of a bare assert.

Fix
In `build_dp_schedule`, before aligning, drop the minimum number of whole trailing rollouts until the up-rounded target is reachable, then re-pack. Rollouts stay atomic (a GRPO group / multi-sample rollout is never split across the drop), and the kept rollout count is recorded in `global_batch_sizes[s]` so per-group loss denominators track what actually trains. The unreachable residual raises the actionable `ValueError` above. The micro-batch count helper is factored into `_aligned_target` and sample collection into `_collect_step_samples` / `_pack` for reuse across the drop loop.

Why this is the right fix
- The drop loop is gated on `use_dynamic_batch_size and align_to > 1` and only runs when the aligned target already exceeds the step's sample count. The common one-sample-per-rollout case and the entire static-batch path are untouched, and `global_batch_sizes` stays equal to `global_batch_size`.
- Recording the kept rollout count in `global_batch_sizes` (rather than assuming `global_batch_size`) keeps per-group loss denominators consistent with the samples that actually train after a drop.
- Only trailing rollouts are dropped, so the kept rollouts remain a contiguous prefix `0..k-1` with no holes.
- The unreachable residual raises a `ValueError` naming the step, kept sample count, and `align_to`, and points at `global_batch_size` / `n_samples_per_prompt` / `max_tokens_per_gpu`, replacing the bare assert.
- `tests/test_dp_schedule.py` is CPU-only and torch-free, so it runs under the existing unit-test job. Three new `@pytest.mark.unit` cases cover the ragged fan-out drop (514 -> kept prefix that tiles), dropping past the naive `dp_size` floor down to a tiling prefix (8 samples), and the unreachable-alignment `ValueError`. The existing suite (including the FLOPs-balancing tests) continues to pass.
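The drop rule described above can be illustrated with a small standalone sketch. It is not the code in `build_dp_schedule`: the function name `keep_aligned_prefix` and the error wording are hypothetical, and it assumes the long-trajectory regime where every sample is its own micro-batch, so a prefix is reachable exactly when its sample count is a multiple of `align_to`.

```python
from itertools import accumulate
from typing import Sequence, Tuple


def keep_aligned_prefix(
    rollout_sizes: Sequence[int], align_to: int, step: int = 0
) -> Tuple[int, int]:
    """Return (kept_rollouts, kept_samples) after dropping the fewest whole
    trailing rollouts. Assumes one sample per micro-batch (illustration only)."""
    prefix_samples = list(accumulate(rollout_sizes))
    # Try the full step first, then progressively shorter prefixes, so the
    # first hit drops the minimum number of trailing rollouts.
    for kept in range(len(rollout_sizes), 0, -1):
        samples = prefix_samples[kept - 1]
        if samples < align_to:
            break  # no full aligned block remains
        if samples % align_to == 0:
            return kept, samples
    total = prefix_samples[-1] if prefix_samples else 0
    # Hypothetical message; the PR's actual wording is not reproduced here.
    raise ValueError(
        f"step {step}: {total} samples cannot be aligned to "
        f"align_to={align_to} by dropping whole trailing rollouts"
    )


fan_out = keep_aligned_prefix([2, 2] + [1] * 510, align_to=4)
ragged = keep_aligned_prefix([3, 2, 3, 1, 2], align_to=4)
```

Rollouts are only ever removed whole and from the tail, so the kept set is always a prefix. For sizes `[1,1,4,1,3,1]` no prefix sample count is a multiple of 4, so the sketch raises `ValueError`.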