77import json
88import logging
99import signal
10+ import sys
1011import time
1112import uuid
1213from abc import ABC , abstractmethod
2829
2930from openai import AsyncOpenAI , BadRequestError , OpenAI
3031
32+ from verifiers .utils .eval_utils import filter_inputs
3133from verifiers .utils .worker_utils import get_free_port
3234from verifiers .workers .client .zmq_env_client import ZMQEnvClient
3335from verifiers .workers .server .zmq_env_server import ZMQEnvServer
7173from verifiers .utils .save_utils import (
7274 GenerateOutputsBuilder ,
7375 make_dataset ,
74- save_generate_outputs ,
76+ push_results_to_hf_hub ,
77+ save_metadata ,
78+ save_new_outputs ,
79+ save_outputs ,
7580 state_to_output ,
7681)
7782from verifiers .utils .token_utils import (
@@ -673,12 +678,15 @@ async def init_state(
673678 state_input ["info" ] = json .loads (state_input ["info" ])
674679 if "task" not in state_input :
675680 state_input ["task" ] = self .env_id or "default"
681+ # Extract rollout_idx before creating RolloutInput (it's not part of RolloutInput)
682+ rollout_idx = state_input .pop ("rollout_idx" , 0 )
676683 state = State (input = RolloutInput (** state_input )) # type: ignore[missing-typed-dict-key]
677684 state ["client" ] = client
678685 state ["model" ] = model
679686 state ["sampling_args" ] = sampling_args
680687 state ["is_completed" ] = False
681688 state ["is_truncated" ] = False
689+ state ["rollout_idx" ] = rollout_idx
682690 state ["oai_tools" ] = None
683691 if "info" in state and hasattr (state ["info" ], "oai_tools" ):
684692 state ["oai_tools" ] = state ["info" ]["oai_tools" ]
@@ -845,7 +853,6 @@ async def generate(
845853 results_path : Path | None = None ,
846854 state_columns : list [str ] | None = None ,
847855 save_results : bool = False ,
848- save_every : int = - 1 ,
849856 push_to_hf_hub : bool = False ,
850857 hf_hub_dataset_name : str | None = None ,
851858 use_tqdm : bool = True ,
@@ -865,6 +872,10 @@ async def generate(
865872 elif isinstance (inputs , list ):
866873 inputs_list = inputs
867874
875+ if not inputs_list :
876+ self .logger .info ("No inputs to generate" )
877+ sys .exit (0 )
878+
868879 # notify caller of actual total count (useful when num_examples=-1)
869880 if on_start is not None :
870881 on_start (len (inputs_list ))
@@ -879,7 +890,7 @@ async def generate(
879890 else :
880891 sampling_args = default_sampling_args
881892
882- # Initialize builder for incremental serialization
893+ # initialize generate outputs builder
883894 builder = GenerateOutputsBuilder (
884895 env_id = self .env_id ,
885896 env_args = self .env_args ,
@@ -947,15 +958,14 @@ async def generate(
947958 # process tasks as they complete
948959 reward_sum , reward_count = 0 , 0
949960 groups_or_rollouts_completed = 0
961+ outputs : list [RolloutOutput ] = []
950962 try :
951963 for coro in asyncio .as_completed (tasks .keys ()):
952964 result = await coro
953965
954966 # normalize: independent_scoring returns RolloutOutput, group returns list[RolloutOutput]
955- outputs = [result ] if independent_scoring else result
956-
957- # Serialize states to outputs immediately (serialization happens once here)
958- new_outputs = builder .add_outputs (outputs )
967+ new_outputs = [result ] if independent_scoring else result
968+ builder .add_outputs (new_outputs )
959969 groups_or_rollouts_completed += 1
960970
961971 # track reward for rolling average (from outputs)
@@ -971,19 +981,15 @@ async def generate(
971981 if reward_count > 0 :
972982 pbar .set_postfix (reward = f"{ reward_sum / reward_count :.3f} " )
973983 elif on_progress is not None :
974- on_progress (builder .outputs , new_outputs )
975-
976- # save intermediate results (outputs already serialized, no redundant work)
977- if (
978- save_results
979- and save_every > 0
980- and groups_or_rollouts_completed % save_every == 0
981- ):
982- intermediate_results = builder .build ()
984+ on_progress (outputs , new_outputs )
985+
986+ if save_results :
987+ # incrementally save outputs
988+ save_new_outputs (new_outputs , builder .results_path )
989+ save_metadata (builder .build_metadata (), builder .results_path )
983990 self .logger .debug (
984- f"Saving intermediate results to { intermediate_results [ 'metadata' ][ 'path_to_save' ] } "
991+ f"Saved { len ( new_outputs ) } new outputs to { builder . results_path } "
985992 )
986- save_generate_outputs (intermediate_results )
987993 finally :
988994 # cancel all outstanding tasks and await their completion
989995 pending = [task for task in tasks .keys () if not task .done ()]
@@ -999,7 +1005,10 @@ async def generate(
9991005
10001006 # save if requested
10011007 if save_results :
1002- save_generate_outputs (results , push_to_hf_hub , hf_hub_dataset_name )
1008+ save_outputs (results ["outputs" ], builder .results_path )
1009+ save_metadata (results ["metadata" ], builder .results_path )
1010+ if push_to_hf_hub :
1011+ push_results_to_hf_hub (results , hf_hub_dataset_name )
10031012 if on_log is not None :
10041013 on_log (f"Saved final outputs to { results ['metadata' ]['path_to_save' ]} " )
10051014
@@ -1063,7 +1072,7 @@ async def evaluate(
10631072 results_path : Path | None = None ,
10641073 state_columns : list [str ] | None = None ,
10651074 save_results : bool = False ,
1066- save_every : int = - 1 ,
1075+ resume_path : Path | None = None ,
10671076 push_to_hf_hub : bool = False ,
10681077 hf_hub_dataset_name : str | None = None ,
10691078 use_tqdm : bool = True ,
@@ -1078,6 +1087,8 @@ async def evaluate(
10781087 Evaluate model on the Environment evaluation dataset.
10791088 """
10801089 inputs = self ._get_eval_inputs (num_examples , rollouts_per_example )
1090+ if resume_path is not None :
1091+ inputs = filter_inputs (inputs , resume_path , rollouts_per_example )
10811092 return await self .generate (
10821093 inputs ,
10831094 client = client ,
@@ -1087,7 +1098,6 @@ async def evaluate(
10871098 results_path = results_path ,
10881099 state_columns = state_columns ,
10891100 save_results = save_results ,
1090- save_every = save_every ,
10911101 push_to_hf_hub = push_to_hf_hub ,
10921102 hf_hub_dataset_name = hf_hub_dataset_name ,
10931103 use_tqdm = use_tqdm ,
@@ -1110,7 +1120,7 @@ def evaluate_sync(
11101120 results_path : Path | None = None ,
11111121 state_columns : list [str ] | None = None ,
11121122 save_results : bool = False ,
1113- save_every : int = - 1 ,
1123+ resume_path : Path | None = None ,
11141124 push_to_hf_hub : bool = False ,
11151125 hf_hub_dataset_name : str | None = None ,
11161126 independent_scoring : bool = False ,
@@ -1129,7 +1139,7 @@ def evaluate_sync(
11291139 results_path = results_path ,
11301140 state_columns = state_columns ,
11311141 save_results = save_results ,
1132- save_every = save_every ,
1142+ resume_path = resume_path ,
11331143 push_to_hf_hub = push_to_hf_hub ,
11341144 hf_hub_dataset_name = hf_hub_dataset_name ,
11351145 independent_scoring = independent_scoring ,
0 commit comments