
Commit f7485d4 ("attempt 1")
Parent: a612201

7 files changed: 213 additions, 157 deletions

verifiers/envs/environment.py

Lines changed: 33 additions & 23 deletions
@@ -7,6 +7,7 @@
 import json
 import logging
 import signal
+import sys
 import time
 import uuid
 from abc import ABC, abstractmethod
@@ -28,6 +29,7 @@
 
 from openai import AsyncOpenAI, BadRequestError, OpenAI
 
+from verifiers.utils.eval_utils import filter_inputs
 from verifiers.utils.worker_utils import get_free_port
 from verifiers.workers.client.zmq_env_client import ZMQEnvClient
 from verifiers.workers.server.zmq_env_server import ZMQEnvServer
@@ -71,7 +73,10 @@
 from verifiers.utils.save_utils import (
     GenerateOutputsBuilder,
     make_dataset,
-    save_generate_outputs,
+    push_results_to_hf_hub,
+    save_metadata,
+    save_new_outputs,
+    save_outputs,
     state_to_output,
 )
 from verifiers.utils.token_utils import (
@@ -673,12 +678,15 @@ async def init_state(
             state_input["info"] = json.loads(state_input["info"])
         if "task" not in state_input:
             state_input["task"] = self.env_id or "default"
+        # Extract rollout_idx before creating RolloutInput (it's not part of RolloutInput)
+        rollout_idx = state_input.pop("rollout_idx", 0)
         state = State(input=RolloutInput(**state_input))  # type: ignore[missing-typed-dict-key]
         state["client"] = client
         state["model"] = model
         state["sampling_args"] = sampling_args
         state["is_completed"] = False
         state["is_truncated"] = False
+        state["rollout_idx"] = rollout_idx
         state["oai_tools"] = None
         if "info" in state and hasattr(state["info"], "oai_tools"):
             state["oai_tools"] = state["info"]["oai_tools"]
@@ -845,7 +853,6 @@ async def generate(
         results_path: Path | None = None,
         state_columns: list[str] | None = None,
         save_results: bool = False,
-        save_every: int = -1,
         push_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         use_tqdm: bool = True,
@@ -865,6 +872,10 @@
         elif isinstance(inputs, list):
             inputs_list = inputs
 
+        if not inputs_list:
+            self.logger.info("No inputs to generate")
+            sys.exit(0)
+
         # notify caller of actual total count (useful when num_examples=-1)
         if on_start is not None:
             on_start(len(inputs_list))
@@ -879,7 +890,7 @@
         else:
             sampling_args = default_sampling_args
 
-        # Initialize builder for incremental serialization
+        # initialize generate outputs builder
        builder = GenerateOutputsBuilder(
            env_id=self.env_id,
            env_args=self.env_args,
@@ -947,15 +958,14 @@
         # process tasks as they complete
         reward_sum, reward_count = 0, 0
         groups_or_rollouts_completed = 0
+        outputs: list[RolloutOutput] = []
         try:
             for coro in asyncio.as_completed(tasks.keys()):
                 result = await coro
 
                 # normalize: independent_scoring returns RolloutOutput, group returns list[RolloutOutput]
-                outputs = [result] if independent_scoring else result
-
-                # Serialize states to outputs immediately (serialization happens once here)
-                new_outputs = builder.add_outputs(outputs)
+                new_outputs = [result] if independent_scoring else result
+                builder.add_outputs(new_outputs)
                 groups_or_rollouts_completed += 1
 
                 # track reward for rolling average (from outputs)
@@ -971,19 +981,15 @@
                     if reward_count > 0:
                         pbar.set_postfix(reward=f"{reward_sum / reward_count:.3f}")
                 elif on_progress is not None:
-                    on_progress(builder.outputs, new_outputs)
-
-                # save intermediate results (outputs already serialized, no redundant work)
-                if (
-                    save_results
-                    and save_every > 0
-                    and groups_or_rollouts_completed % save_every == 0
-                ):
-                    intermediate_results = builder.build()
+                    on_progress(outputs, new_outputs)
+
+                if save_results:
+                    # incrementally save outputs
+                    save_new_outputs(new_outputs, builder.results_path)
+                    save_metadata(builder.build_metadata(), builder.results_path)
                     self.logger.debug(
-                        f"Saving intermediate results to {intermediate_results['metadata']['path_to_save']}"
+                        f"Saved {len(new_outputs)} new outputs to {builder.results_path}"
                     )
-                    save_generate_outputs(intermediate_results)
         finally:
             # cancel all outstanding tasks and await their completion
             pending = [task for task in tasks.keys() if not task.done()]
@@ -999,7 +1005,10 @@
 
         # save if requested
         if save_results:
-            save_generate_outputs(results, push_to_hf_hub, hf_hub_dataset_name)
+            save_outputs(results["outputs"], builder.results_path)
+            save_metadata(results["metadata"], builder.results_path)
+            if push_to_hf_hub:
+                push_results_to_hf_hub(results, hf_hub_dataset_name)
         if on_log is not None:
             on_log(f"Saved final outputs to {results['metadata']['path_to_save']}")
 
@@ -1063,7 +1072,7 @@ async def evaluate(
         results_path: Path | None = None,
         state_columns: list[str] | None = None,
         save_results: bool = False,
-        save_every: int = -1,
+        resume_path: Path | None = None,
         push_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         use_tqdm: bool = True,
@@ -1078,6 +1087,8 @@
         Evaluate model on the Environment evaluation dataset.
         """
         inputs = self._get_eval_inputs(num_examples, rollouts_per_example)
+        if resume_path is not None:
+            inputs = filter_inputs(inputs, resume_path, rollouts_per_example)
         return await self.generate(
             inputs,
             client=client,
@@ -1087,7 +1098,6 @@
             results_path=results_path,
             state_columns=state_columns,
             save_results=save_results,
-            save_every=save_every,
             push_to_hf_hub=push_to_hf_hub,
             hf_hub_dataset_name=hf_hub_dataset_name,
             use_tqdm=use_tqdm,
@@ -1110,7 +1120,7 @@ def evaluate_sync(
         results_path: Path | None = None,
         state_columns: list[str] | None = None,
         save_results: bool = False,
-        save_every: int = -1,
+        resume_path: Path | None = None,
         push_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         independent_scoring: bool = False,
@@ -1129,7 +1139,7 @@
             results_path=results_path,
             state_columns=state_columns,
             save_results=save_results,
-            save_every=save_every,
+            resume_path=resume_path,
             push_to_hf_hub=push_to_hf_hub,
             hf_hub_dataset_name=hf_hub_dataset_name,
             independent_scoring=independent_scoring,
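
The change above replaces the periodic save_every snapshot with an append-as-you-go scheme: each completed group goes through save_new_outputs, and save_metadata rewrites the small metadata file. Those helpers live in verifiers/utils/save_utils.py, which this commit does not show; the sketch below is one plausible shape for them, with the file names taken from is_valid_eval_results_path further down and everything else an assumption.

# Hypothetical sketch of the save_utils helpers imported above; the real
# implementations are not part of this commit. The file names match what
# is_valid_eval_results_path checks for; the JSON layout is an assumption.
import json
from pathlib import Path

def save_new_outputs(new_outputs: list[dict], results_path: Path) -> None:
    # append-only: each completed group adds lines and never rewrites old ones,
    # which is what makes per-group saving cheap compared to periodic snapshots
    results_path.mkdir(parents=True, exist_ok=True)
    with open(results_path / "results.jsonl", "a") as f:
        for output in new_outputs:
            f.write(json.dumps(output) + "\n")

def save_metadata(metadata: dict, results_path: Path) -> None:
    # metadata stays small, so rewriting the whole file on each update is cheap
    with open(results_path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)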

verifiers/scripts/eval.py

Lines changed: 17 additions & 7 deletions
@@ -1,5 +1,7 @@
 import os
 
+from verifiers.utils.path_utils import is_valid_eval_results_path
+
 # Suppress tokenizers parallelism warning (only prints when env var is unset)
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
 
@@ -212,15 +214,13 @@ def main():
         help="Save results to disk",
     )
     parser.add_argument(
-        "--save-every",
-        "-f",
-        type=int,
-        default=DEFAULT_SAVE_EVERY,
-        help="Save dataset every n rollouts (-1 to disable)",
+        "--resume-path",
+        type=str,
+        default=None,
+        help="Resume from a previous run.",
     )
     parser.add_argument(
         "--independent-scoring",
-        "-R",
         default=False,
         action="store_true",
         help="Score each rollout individually instead of scoring by group",
@@ -389,6 +389,16 @@ def build_eval_config(raw: dict) -> EvalConfig:
         extra_headers=merged_headers,
     )
 
+    # handle resume path resolution
+    resume_path = raw.get("resume_path")
+    if resume_path is not None:
+        resume_path = Path(resume_path)
+        if not is_valid_eval_results_path(resume_path):
+            raise ValueError(
+                f"Resume path {resume_path} is not a valid evaluation results path"
+            )
+        logger.info(f"Resuming from: {resume_path}")
+
     return EvalConfig(
         env_id=env_id,
         env_args=raw.get("env_args", {}),
@@ -404,7 +414,7 @@
         verbose=raw.get("verbose", False),
         state_columns=raw.get("state_columns", []),
         save_results=raw.get("save_results", False),
-        save_every=raw.get("save_every", DEFAULT_SAVE_EVERY),
+        resume_path=resume_path,
         independent_scoring=raw.get("independent_scoring", False),
         save_to_hf_hub=raw.get("save_to_hf_hub", False),
         hf_hub_dataset_name=raw.get("hf_hub_dataset_name", ""),
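
Downstream, build_eval_config is what wires the new flag into EvalConfig. A minimal sketch of driving it directly: the keys mirror the raw.get(...) calls above, while the env_id value and the path are made up, and other required config keys are elided.

# sketch: resume keys as build_eval_config reads them; values are hypothetical
raw = {
    "env_id": "my-env",                           # made-up environment id
    "save_results": True,
    "resume_path": "outputs/evals/my-env/run-1",  # must contain results.jsonl and metadata.json
}
config = build_eval_config(raw)  # raises ValueError if resume_path is not a valid results dir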

verifiers/types.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ class EvalConfig(BaseModel):
     # saving
     state_columns: list[str] | None = None
     save_results: bool = False
-    save_every: int = -1
+    resume_path: Path | None = None
     save_to_hf_hub: bool = False
     hf_hub_dataset_name: str | None = None

verifiers/utils/eval_display.py

Lines changed: 0 additions & 4 deletions
@@ -318,10 +318,6 @@ def fmt_concurrency(val: int) -> str:
     if config.save_results:
         config_line.append(" | ", style="dim")
         config_line.append("saving results", style="white")
-        if config.save_every > 0:
-            config_line.append(" every ", style="dim")
-            config_line.append(str(config.save_every), style="white")
-            config_line.append(" steps", style="dim")
 
     # create progress bar with timing
     # use env_state.total which gets updated by on_start callback

verifiers/utils/eval_utils.py

Lines changed: 28 additions & 2 deletions
@@ -13,6 +13,8 @@
 from datasets import disable_progress_bar, enable_progress_bar
 from datasets.utils import logging as ds_logging
 
+from verifiers.utils.save_utils import load_outputs
+
 try:
     import tomllib  # type: ignore[import-not-found]
 except ImportError:
@@ -31,6 +33,7 @@
     GenerateOutputs,
     LogCallback,
     ProgressCallback,
+    RolloutInput,
     RolloutOutput,
     StartCallback,
 )
@@ -181,6 +184,30 @@ def load_toml_config(path: Path) -> list[dict]:
     return merged_eval_list
 
 
+def filter_inputs(
+    inputs: list[RolloutInput], results_path: Path, rollouts_per_example: int
+) -> list[RolloutInput]:
+    """Drop inputs whose example already has enough saved rollouts."""
+    saved_outputs = load_outputs(results_path)
+
+    inputs_by_example_id, outputs_by_example_id = defaultdict(list), defaultdict(list)
+    for input in inputs:
+        inputs_by_example_id[input["example_id"]].append(input)
+    for output in saved_outputs:
+        outputs_by_example_id[output["example_id"]].append(output)
+
+    filtered_inputs = []
+    for example_id in inputs_by_example_id.keys():
+        example_inputs = inputs_by_example_id[example_id]
+        example_outputs = outputs_by_example_id[example_id]
+        # number of rollouts still missing for this example
+        rollouts_left = rollouts_per_example - len(example_outputs)
+        if rollouts_left > 0:
+            filtered_inputs.extend(example_inputs[:rollouts_left])
+
+    return filtered_inputs
+
+
 def to_col_order(list_of_dicts: list[Mapping[str, float]]) -> dict[str, list[float]]:
     """Convert a list of mappings to a dictionary of lists, ordered by the keys of the first mapping."""
     if not list_of_dicts:
@@ -339,7 +365,7 @@ async def run_evaluation(
         await vf_env.start_server(extra_env_kwargs=config.extra_env_kwargs)
 
     # run evaluation
-    results_path = get_eval_results_path(config)
+    results_path = config.resume_path or get_eval_results_path(config)
     logger.debug(f"Starting evaluation with model: {config.model}")
     logger.debug(
         f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}"
@@ -356,7 +382,7 @@
         results_path=results_path,
         state_columns=config.state_columns,
         save_results=config.save_results,
-        save_every=config.save_every,
+        resume_path=config.resume_path,
         push_to_hf_hub=config.save_to_hf_hub,
         hf_hub_dataset_name=config.hf_hub_dataset_name,
         use_tqdm=use_tqdm,
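
Concretely, filter_inputs keeps only the rollouts an example is still missing. A toy walkthrough of the same arithmetic, with inputs and saved outputs as plain dicts keyed by example_id (load_outputs, which normally reads results.jsonl, is replaced here by a hard-coded list):

from collections import defaultdict

# 2 examples x 3 rollouts requested; example "a" already has 2 saved rollouts
inputs = [{"example_id": e, "rollout_idx": i} for e in ("a", "b") for i in range(3)]
saved = [{"example_id": "a"}, {"example_id": "a"}]

outputs_by_id = defaultdict(list)
for o in saved:
    outputs_by_id[o["example_id"]].append(o)

kept = []
for e in ("a", "b"):
    rollouts_left = 3 - len(outputs_by_id[e])  # rollouts_per_example - saved
    kept.extend([i for i in inputs if i["example_id"] == e][:rollouts_left])

assert len(kept) == 4  # 1 more rollout for "a", all 3 for "b"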

verifiers/utils/path_utils.py

Lines changed: 13 additions & 0 deletions
@@ -1,8 +1,11 @@
+import logging
 import uuid
 from pathlib import Path
 
 from verifiers.types import EvalConfig
 
+logger = logging.getLogger(__name__)
+
 
 def get_results_path(
     env_id: str,
@@ -28,6 +31,16 @@ def get_eval_results_path(config: EvalConfig) -> Path:
     return results_path
 
 
+def is_valid_eval_results_path(path: Path) -> bool:
+    """Check whether a path points to a saved evaluation results directory."""
+    return (
+        path.exists()
+        and path.is_dir()
+        and (path / "results.jsonl").exists()
+        and (path / "metadata.json").exists()
+    )
+
+
 def get_gepa_results_path(
     env_id: str,
     model: str,
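
The guard is purely structural: any directory containing both files counts as resumable, regardless of their contents. A quick sanity check against a scratch directory (the path is made up):

from pathlib import Path

run_dir = Path("outputs/evals/demo-run")  # hypothetical location
run_dir.mkdir(parents=True, exist_ok=True)
assert not is_valid_eval_results_path(run_dir)  # empty dir: not resumable

(run_dir / "results.jsonl").touch()
(run_dir / "metadata.json").touch()
assert is_valid_eval_results_path(run_dir)  # both files present: resumable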
