Skip to content

fix(metrics): make compute_pass_rate ragged-safe for over-sampled batches#2064

Open
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/metric-ragged-passrate
Open

fix(metrics): make compute_pass_rate ragged-safe for over-sampled batches#2064
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/metric-ragged-passrate

Conversation

@EazyReal

@EazyReal EazyReal commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Problem

compute_pass_rate in slime/utils/metric_utils.py assumes a rigid layout: every prompt-group has exactly group_size samples laid out contiguously, so it reshapes flat_rewards to (num_groups, group_size) behind

assert len(flat_rewards) == num_groups * group_size, f"{len(flat_rewards)=} {num_groups=} {group_size=}"
rewards_of_group = np.array(flat_rewards).reshape(num_groups, group_size)

Over-sampled batches break that assumption. Dynamic sampling, group replacement, and per-group sample drops produce a variable number of samples per prompt-group and a total that need not divide group_size — e.g. 51 trainable samples across mixed-size groups, not a rigid 12 * 4 = 48. On such a batch the reshape assert fires and crashes metric reporting (which otherwise runs fine), so the run dies at logging time rather than on a real training error.

Before vs After

Same input, a ragged batch of 51 rewards across 12 groups of mixed size ([4,4,3,4,4,4,5,4,4,4,4,7]).

Before — the rigid reshape path is the only path; the total does not divide group_size, so the assert crashes:

compute_pass_rate(flat_rewards, group_size=4, num_groups=12)
# AssertionError: len(flat_rewards)=51 num_groups=12 group_size=4

After — pass per-sample group_ids; rewards are bucketed by their actual group id and pass@k is estimated over the samples that actually exist per group:

compute_pass_rate(flat_rewards, group_size=4, group_ids=group_ids)
# {'pass@1': 0.528..., 'pass@2': 0.855..., 'pass@4': 1.0}

A smaller hand-checkable case shows the per-group semantics — group a = [1,1,0] (3 samples, 2 correct), group b = [0,0] (2 samples, 0 correct):

compute_pass_rate([1, 1, 0, 0, 0], group_size=4, group_ids=["a", "a", "a", "b", "b"])
# {'pass@1': 0.333..., 'pass@2': 0.5}
#   pass@1 = mean correct fraction = (2/3 + 0/2)/2 = 1/3
#   pass@2 = mean over groups with >= 2 samples: a -> 1.0, b -> 0.0  => 0.5
#   pass@4 is dropped: no group has >= 4 samples, so the rung cannot be estimated

Fix

Two parts: a ragged code path in the metric, and call-site wiring that feeds it real group ids.

Metric: optional group_ids in compute_pass_rate

  • When group_ids is given, rewards are bucketed by their actual group id and pass@k is estimated over the samples that actually exist per group.
  • pass@k for a rung is averaged only over groups with >= k samples — a group with fewer than k samples cannot define an unbiased pass@k draw, so it is excluded from that rung's mean rather than counted as a trivial 1.0.
  • Rungs whose every group is too small are dropped instead of asserting.
  • The ragged path validates len(flat_rewards) == len(group_ids) instead of the divisibility assert.

Call sites: pass real per-sample group ids

Train logging (log_passrate in slime/backends/megatron_utils/data.py) now buckets by the data source's real Sample.group_index:

  • _convert_samples_to_train_data packs group_indices = [sample.group_index for sample in samples] parallel to raw_reward. group_index is assigned per prompt-group by RolloutDataSource.get_samples, so it survives dynamic filtering / over-sampling and identifies each sample's actual group even when the batch is ragged.
  • _split_train_data_by_dp ships group_indices unsplit to every rank, exactly like raw_reward (pass@k is a whole-batch metric).
  • log_passrate passes it as group_ids when every entry is non-None. When the metadata is absent or incomplete — custom rollout or convert-samples functions that don't tag samples — it falls back to the rigid rollout_batch_size x n_samples_per_prompt reshape, i.e. exact legacy behavior.
  • log_rollout_data's generic reduce loop skips the new key so group ids are never averaged into a fake rollout/group_indices metric.

Eval logging (_log_eval_rollout_data in slime/ray/rollout.py): when the per-dataset eval dict carries samples and every sample has a non-None group_index, it passes [s.group_index for s in samples] as group_ids. The default eval path (eval_rollout_single_dataset) copies samples straight from the eval dataset and never sets group_index, so it keeps the rigid n_samples_per_eval_prompt layout unchanged; custom rollout functions that tag every sample get ragged bucketing instead of the divisibility assert.

Why this is the right fix

  • Real ids, not synthetic ones. The train path groups by the same group_index the data source assigned at sampling time, so pass@k buckets match the actual prompt groups even after per-group drops or replacements — no layout assumption to violate.

  • Default-path safety. With group_ids=None the function takes the exact legacy reshape (num_samples_per_group = full(num_groups, group_size), num_correct = sum(rewards == 1, axis=1)) and is numerically identical to prior behavior. Both call sites fall back to that path whenever group metadata is missing or incomplete, so existing setups (including custom rollout functions that don't set group_index) emit byte-identical metrics.

  • Minimal, no new abstraction. One optional argument and one branch in a single numpy-only function; the pass@k estimator (_estimate_pass_at_k) is reused unchanged. The wiring is one parallel list that travels next to raw_reward along the existing path.

  • CI-verifiable. Adds tests/utils/test_metric_utils.py pinning both regimes:

    • the ragged 51-sample shape that previously crashed the assert, with exact pass@k values (pass@1 == 0.5281746..., pass@2 == 0.8547619..., pass@4 == 1.0);
    • per-group pass@k eligibility and high-rung dropping (groups too small for a rung are excluded, not counted as 1.0);
    • the length-mismatch assert on the ragged path;
    • that the rigid (group_ids=None) path stays numerically identical to the legacy reshape.

    The test is registered in the CPU (num_gpus: 0) matrix via pr-test.yml.j2, and pr-test.yml is regenerated from the template by generate_github_workflows.py. The metric is numpy-only, so the suite runs in the CPU lane with no GPU.

@EazyReal EazyReal force-pushed the upstream-pr/metric-ragged-passrate branch 2 times, most recently from 03f2bca to 01bad5f Compare June 12, 2026 06:20
…ches

compute_pass_rate assumed a rigid layout: every prompt-group has exactly
group_size samples laid out contiguously, so it reshaped flat_rewards to
(num_groups, group_size) behind an
`assert len(flat_rewards) == num_groups * group_size`.

Over-sampled batches break that assumption. Dynamic sampling, group
replacement, and per-group sample drops yield a variable number of samples
per group and a total that need not divide group_size (e.g. 51 trainable
samples across mixed-size groups, not a rigid 8*4=32). On such a batch the
reshape assert fires and crashes metric reporting.

Fix: add an optional `group_ids` argument that selects a ragged path. When
group_ids is given, rewards are bucketed by their actual group id and pass@k
is estimated over the samples that actually exist per group:

  * pass@k for a rung is averaged only over groups with >= k samples — a group
    with fewer than k samples cannot define an unbiased pass@k draw, so it is
    excluded from that rung's mean rather than counted as a trivial 1.0;
  * rungs whose every group is too small are dropped instead of asserting.

The default path is unchanged: with group_ids=None the function still takes the
legacy reshape and stays numerically identical to the prior behavior.

Wire both call sites to pass real group ids:

  * train (megatron log_passrate): _convert_samples_to_train_data packs
    Sample.group_index (set by the data source for every default-path sample)
    as a "group_indices" list parallel to raw_reward, and
    _split_train_data_by_dp ships it unsplit to every rank exactly like
    raw_reward. log_passrate buckets by it when every entry is present and
    falls back to the rigid rollout_batch_size x n_samples_per_prompt reshape
    when any is None (custom rollout/convert functions that do not tag
    samples). log_rollout_data skips the new key in its generic metric loop.
  * eval (_log_eval_rollout_data): when the eval data dict carries samples and
    every sample has a non-None group_index, pass them as group_ids; the
    default eval path never sets group_index (samples are copied straight from
    the eval dataset), so it keeps the rigid n_samples_per_eval_prompt layout.

Adds tests/utils/test_metric_utils.py pinning both regimes — including the
ragged 51-sample shape that previously crashed the assert, the per-group
eligibility/rung-dropping semantics, and that the rigid path matches the legacy
reshape — and registers it in the CPU (num_gpus: 0) CI matrix via the
pr-test.yml.j2 template (pr-test.yml regenerated from the template).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@EazyReal EazyReal force-pushed the upstream-pr/metric-ragged-passrate branch from 01bad5f to 9927c9b Compare June 12, 2026 08:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant