fix(metrics): make compute_pass_rate ragged-safe for over-sampled batches #2064
Open
EazyReal wants to merge 1 commit into
compute_pass_rate assumed a rigid layout: every prompt-group has exactly
group_size samples laid out contiguously, so it reshaped flat_rewards to
(num_groups, group_size) behind an
`assert len(flat_rewards) == num_groups * group_size`.
Over-sampled batches break that assumption. Dynamic sampling, group
replacement, and per-group sample drops yield a variable number of samples
per group and a total that need not be a multiple of group_size (e.g. 51
trainable samples across mixed-size groups, not a rigid 8*4=32). On such a
batch the reshape assert fires and crashes metric reporting.
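The failure mode can be reproduced in isolation. The snippet below is only a sketch: `rigid_reshape` is a hypothetical stand-in for the legacy layout check, and deriving `num_groups` by integer division is an assumption, not the function's actual code. The group sizes are the mixed sizes quoted in the PR description.

```python
import numpy as np


def rigid_reshape(flat_rewards, group_size):
    """Hypothetical stand-in for the legacy path: requires a rigid layout."""
    # Assumption: the number of groups is derived from the flat length.
    num_groups = len(flat_rewards) // group_size
    assert len(flat_rewards) == num_groups * group_size, "ragged batch"
    return np.asarray(flat_rewards).reshape(num_groups, group_size)


# 12 groups of mixed size, 51 samples in total.
group_sizes = [4, 4, 3, 4, 4, 4, 5, 4, 4, 4, 4, 7]
flat_rewards = [0.0] * sum(group_sizes)

try:
    rigid_reshape(flat_rewards, group_size=4)
    crashed = False
except AssertionError:
    # 51 is not a multiple of 4, so the layout assert fires.
    crashed = True
```

A batch of exactly 48 rewards passes the same check, which is why the crash only shows up once over-sampling changes the per-group counts.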
Fix: add an optional `group_ids` argument that selects a ragged path. When
group_ids is given, rewards are bucketed by their actual group id and pass@k
is estimated over the samples that actually exist per group:
* pass@k for a rung is averaged only over groups with >= k samples — a group
with fewer than k samples cannot define an unbiased pass@k draw, so it is
excluded from that rung's mean rather than counted as a trivial 1.0;
* rungs whose every group is too small are dropped instead of asserting.
The default path is unchanged: with group_ids=None the function still takes the
legacy reshape and stays numerically identical to the prior behavior.
Wire both call sites to pass real group ids:
* train (megatron log_passrate): _convert_samples_to_train_data packs
Sample.group_index (set by the data source for every default-path sample)
as a "group_indices" list parallel to raw_reward, and
_split_train_data_by_dp ships it unsplit to every rank exactly like
raw_reward. log_passrate buckets by it when every entry is present and
falls back to the rigid rollout_batch_size x n_samples_per_prompt reshape
when any is None (custom rollout/convert functions that do not tag
samples). log_rollout_data skips the new key in its generic metric loop.
* eval (_log_eval_rollout_data): when the eval data dict carries samples and
every sample has a non-None group_index, pass them as group_ids; the
default eval path never sets group_index (samples are copied straight from
the eval dataset), so it keeps the rigid n_samples_per_eval_prompt layout.
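Both call sites apply the same rule: use real group ids only when every sample carries one, otherwise fall back to the rigid layout. A minimal sketch of that rule follows. `FakeSample` and `select_group_ids` are hypothetical names introduced for illustration, and treating an empty batch as "no ids" is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class FakeSample:
    """Hypothetical stand-in for a rollout sample; only the fields used here."""
    reward: float
    group_index: Optional[int] = None


def select_group_ids(samples: Sequence[FakeSample]) -> Optional[List[int]]:
    """Return per-sample group ids, or None to request the rigid fallback."""
    ids = [s.group_index for s in samples]
    # Assumption: an empty batch is treated like missing metadata.
    if not ids or any(i is None for i in ids):
        return None
    return ids


tagged = [FakeSample(1.0, 0), FakeSample(0.0, 0), FakeSample(1.0, 3)]
partly_tagged = [FakeSample(1.0, 0), FakeSample(0.0, None)]

ragged_ids = select_group_ids(tagged)           # [0, 0, 3]: ragged path
fallback_ids = select_group_ids(partly_tagged)  # None: rigid reshape
```

Returning `None` rather than a partial list keeps the decision all-or-nothing, so a batch is never bucketed by ids that cover only some of its samples.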
Adds tests/utils/test_metric_utils.py pinning both regimes — including the
ragged 51-sample shape that previously crashed the assert, the per-group
eligibility/rung-dropping semantics, and that the rigid path matches the legacy
reshape — and registers it in the CPU (num_gpus: 0) CI matrix via the
pr-test.yml.j2 template (pr-test.yml regenerated from the template).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Problem
`compute_pass_rate` in `slime/utils/metric_utils.py` assumes a rigid layout: every prompt-group has exactly `group_size` samples laid out contiguously, so it reshapes `flat_rewards` to `(num_groups, group_size)` behind an `assert len(flat_rewards) == num_groups * group_size`.
Over-sampled batches break that assumption. Dynamic sampling, group replacement, and per-group sample drops produce a variable number of samples per prompt-group and a total that need not be a multiple of `group_size`: for example, 51 trainable samples across mixed-size groups rather than a rigid `12 * 4 = 48`. On such a batch the reshape assert fires and crashes metric reporting, so an otherwise healthy run dies at logging time rather than on a real training error.
Before vs After
Same input: a ragged batch of 51 rewards across 12 groups of mixed size (`[4,4,3,4,4,4,5,4,4,4,4,7]`).
Before: the rigid reshape path is the only path; the total is not a multiple of `group_size`, so the assert fires and metric reporting crashes.
After: pass per-sample `group_ids`; rewards are bucketed by their actual group id and pass@k is estimated over the samples that actually exist per group.
A smaller hand-checkable case shows the per-group semantics: group `a = [1,1,0]` (3 samples, 2 correct) and group `b = [0,0]` (2 samples, 0 correct). Both groups count toward k = 1 and k = 2, while only `a` has the 3 samples needed for k = 3.
Fix
Two parts: a ragged code path in the metric, and call-site wiring that feeds it real group ids.
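As a rough illustration of the metric half, the sketch below buckets rewards by group id and averages each rung only over eligible groups. It is not the PR's implementation: `ragged_pass_rate`, the `pass@k` key names, and the choice of rungs are hypothetical, and `pass_at_k` assumes the standard unbiased estimator `1 - C(n - c, k) / C(n, k)`. A reward equal to 1 is counted as correct.

```python
from collections import defaultdict
from math import comb


def pass_at_k(n, c, k):
    """Unbiased pass@k for n samples with c correct; requires n >= k."""
    # math.comb returns 0 when fewer than k incorrect samples exist.
    return 1.0 - comb(n - c, k) / comb(n, k)


def ragged_pass_rate(flat_rewards, group_ids, ks):
    """Hypothetical sketch: per-rung mean over groups with at least k samples."""
    if len(flat_rewards) != len(group_ids):
        raise ValueError("flat_rewards and group_ids must have equal length")
    groups = defaultdict(list)
    for reward, gid in zip(flat_rewards, group_ids):
        groups[gid].append(reward)
    result = {}
    for k in ks:
        estimates = [
            pass_at_k(len(rs), sum(r == 1 for r in rs), k)
            for rs in groups.values()
            if len(rs) >= k  # too-small groups are excluded, not scored 1.0
        ]
        if estimates:  # a rung with no eligible group is dropped entirely
            result[f"pass@{k}"] = sum(estimates) / len(estimates)
    return result


# Hand-checkable case: a = [1, 1, 0], b = [0, 0].
rewards = [1, 1, 0, 0, 0]
ids = ["a", "a", "a", "b", "b"]
rates = ragged_pass_rate(rewards, ids, ks=(1, 2, 4))
# pass@1 = mean(2/3, 0) = 1/3, pass@2 = mean(1, 0) = 0.5, pass@4 is dropped.
```

Under these assumptions, excluding `b` from any rung above k = 2 keeps a short group from inflating the mean, which is the behaviour the eligibility rule is meant to guarantee.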
Metric: optional `group_ids` in `compute_pass_rate`
* When `group_ids` is given, rewards are bucketed by their actual group id and pass@k is estimated over the samples that actually exist per group.
* pass@k for a rung is averaged only over groups with `>= k` samples: a group with fewer than `k` samples cannot define an unbiased pass@k draw, so it is excluded from that rung's mean rather than counted as a trivial `1.0`.
* The ragged path checks `len(flat_rewards) == len(group_ids)` instead of the divisibility assert.
Call sites: pass real per-sample group ids
Train logging (`log_passrate` in `slime/backends/megatron_utils/data.py`) now buckets by the data source's real `Sample.group_index`:
* `_convert_samples_to_train_data` packs `group_indices = [sample.group_index for sample in samples]` parallel to `raw_reward`. `group_index` is assigned per prompt-group by `RolloutDataSource.get_samples`, so it survives dynamic filtering and over-sampling and identifies each sample's actual group even when the batch is ragged.
* `_split_train_data_by_dp` ships `group_indices` unsplit to every rank, exactly like `raw_reward` (pass@k is a whole-batch metric).
* `log_passrate` passes it as `group_ids` when every entry is non-None. When the metadata is absent or incomplete (custom rollout or convert-samples functions that don't tag samples), it falls back to the rigid `rollout_batch_size x n_samples_per_prompt` reshape, i.e. exact legacy behavior.
* `log_rollout_data`'s generic reduce loop skips the new key so group ids are never averaged into a fake `rollout/group_indices` metric.
Eval logging (`_log_eval_rollout_data` in `slime/ray/rollout.py`): when the per-dataset eval dict carries `samples` and every sample has a non-None `group_index`, it passes `[s.group_index for s in samples]` as `group_ids`. The default eval path (`eval_rollout_single_dataset`) copies samples straight from the eval dataset and never sets `group_index`, so it keeps the rigid `n_samples_per_eval_prompt` layout unchanged; custom rollout functions that tag every sample get ragged bucketing instead of the divisibility assert.
Why this is the right fix
Real ids, not synthetic ones. The train path groups by the same `group_index` the data source assigned at sampling time, so pass@k buckets match the actual prompt groups even after per-group drops or replacements, with no layout assumption to violate.
Default-path safety. With `group_ids=None` the function takes the exact legacy reshape (`num_samples_per_group = full(num_groups, group_size)`, `num_correct = sum(rewards == 1, axis=1)`) and is numerically identical to prior behavior. Both call sites fall back to that path whenever group metadata is missing or incomplete, so existing setups (including custom rollout functions that don't set `group_index`) emit byte-identical metrics.
Minimal, no new abstraction. One optional argument and one branch in a single numpy-only function; the pass@k estimator (`_estimate_pass_at_k`) is reused unchanged. The wiring is one parallel list that travels next to `raw_reward` along the existing path.
CI-verifiable. Adds `tests/utils/test_metric_utils.py` pinning both regimes:
* the ragged 51-sample shape that previously crashed the assert (`pass@1 == 0.5281746...`, `pass@2 == 0.8547619...`, `pass@4 == 1.0`);
* the per-group eligibility and rung-dropping semantics (a too-small group is excluded from a rung rather than counted as a trivial `1.0`);
* the rigid (`group_ids=None`) path stays numerically identical to the legacy reshape.
The test is registered in the CPU (`num_gpus: 0`) matrix via `pr-test.yml.j2`, and `pr-test.yml` is regenerated from the template by `generate_github_workflows.py`. The metric is numpy-only, so the suite runs in the CPU lane with no GPU.
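The kind of equivalence such a test pins can be sketched without the repository. The check below is an illustration only, not the contents of `tests/utils/test_metric_utils.py`: `rigid_pass_at_k` and `ragged_pass_at_k` are hypothetical helpers, and the estimator is assumed to be the standard unbiased `1 - C(n - c, k) / C(n, k)`. On a uniform batch, bucketing by id should agree with the reshape.

```python
from math import comb

import numpy as np


def pass_at_k(n, c, k):
    """Unbiased pass@k for n samples with c correct; requires n >= k."""
    return 1.0 - comb(n - c, k) / comb(n, k)


def rigid_pass_at_k(flat_rewards, group_size, k):
    """Hypothetical legacy-style path: reshape, then count correct per row."""
    rewards = np.asarray(flat_rewards).reshape(-1, group_size)
    num_correct = (rewards == 1).sum(axis=1)
    return float(np.mean([pass_at_k(group_size, int(c), k) for c in num_correct]))


def ragged_pass_at_k(flat_rewards, group_ids, k):
    """Hypothetical ragged path: bucket by id, average over eligible groups."""
    buckets = {}
    for reward, gid in zip(flat_rewards, group_ids):
        buckets.setdefault(gid, []).append(reward)
    eligible = [rs for rs in buckets.values() if len(rs) >= k]
    if not eligible:
        return None  # rung dropped: no group has k samples
    return float(np.mean([pass_at_k(len(rs), sum(r == 1 for r in rs), k) for rs in eligible]))


# Uniform batch: 12 groups of 4 binary rewards, ids laid out contiguously.
rng = np.random.default_rng(0)
flat = rng.integers(0, 2, size=12 * 4).tolist()
ids = np.repeat(np.arange(12), 4).tolist()

for k in (1, 2, 4):
    assert np.isclose(rigid_pass_at_k(flat, 4, k), ragged_pass_at_k(flat, ids, k))
```

Because the check needs only numpy and the standard library, it fits the CPU lane described above.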