-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathgenerate_prediction_file.py
More file actions
304 lines (240 loc) · 9.92 KB
/
generate_prediction_file.py
File metadata and controls
304 lines (240 loc) · 9.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
# SPDX-FileCopyrightText: Copyright contributors to the RouterArena project
# SPDX-License-Identifier: Apache-2.0
"""
Generate Prediction File using configured router.
This script generates a prediction file using the router class specified
in the config file's pipeline_params.router_cls_name field.
Usage:
    python router_inference/generate_prediction_file.py <router_name> <split> [--no-optimality]

    split: one of "sub_10", "full", "robustness", or "gpqa"
"""
import argparse
import json
import os
import sys
from typing import Dict, Any, List
# Add parent directory to path for imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from router_inference.router import BaseRouter
# Dataset JSON file for each supported split. The paths are relative, so they
# only resolve correctly when the working directory is the project root
# (main() switches to it before anything is loaded). The "sub_10" file is also
# read by generate_predictions() to decide which queries get optimality entries.
DATASET_PATHS: Dict[str, str] = {
    "sub_10": "./dataset/router_data_10.json",
    "full": "./dataset/router_data.json",
    "robustness": "./dataset/router_robustness.json",
    "gpqa": "./dataset/gpqa_data.json",
}
def load_dataset(split: str) -> List[Dict[str, Any]]:
    """
    Read and parse the JSON dataset file registered for ``split``.

    Args:
        split: Key into DATASET_PATHS (sub_10, full, robustness, gpqa)

    Returns:
        The parsed JSON content of the dataset file (expected to be a list of
        dataset entries)

    Raises:
        ValueError: If ``split`` has no registered dataset path.
        FileNotFoundError: If the registered dataset file does not exist.
    """
    path = DATASET_PATHS.get(split)
    if not path:
        raise ValueError(
            f"Invalid split: {split}. Must be one of {list(DATASET_PATHS.keys())}"
        )

    # Check up front so a missing dataset yields a message naming the file.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")

    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)
def generate_predictions(
    dataset: List[Dict[str, Any]],
    router: BaseRouter,
    model_pool: List[str],
    split: str,
    include_optimality: bool = True,
) -> List[Dict[str, Any]]:
    """
    Generate predictions using the router, optionally including optimality entries.

    Regular entries:
        One entry (``for_optimality=False``) per dataset entry that has both a
        "global index" and a prompt ("prompt_formatted", falling back to
        "prompt"). Dataset entries missing either are skipped.

    Optimality entries (``sub_10`` and ``full`` splits only):
        For each query whose global index also appears in the sub_10 dataset,
        one extra entry (``for_optimality=True``) is added for every model in
        ``model_pool`` other than the one the router selected. Any other split
        (e.g. ``robustness``, ``gpqa``) gets regular entries only.

    Args:
        dataset: List of dataset entries
        router: Router instance; ``router.get_prediction(prompt)`` picks the model
        model_pool: List of all models in the router's pool
        split: Dataset split ("sub_10", "full", "robustness", or "gpqa")
        include_optimality: Whether to include optimality entries. Forced to
            False for splits that do not support them, or when the sub_10
            dataset cannot be loaded.

    Returns:
        Regular prediction entries followed by any optimality entries. The
        "generated_result", "cost" and "accuracy" fields are left as None.
    """
    predictions: List[Dict[str, Any]] = []

    # Only full/sub_10 support optimality augmentation
    if split not in {"sub_10", "full"}:
        include_optimality = False

    # Load sub_10 indices to identify which entries need optimality calculations
    sub10_indices = set()
    if include_optimality:
        try:
            sub10_dataset = load_dataset("sub_10")
            sub10_indices = {entry.get("global index") for entry in sub10_dataset}
            print(
                f" Loaded {len(sub10_indices)} sub_10 indices for optimality calculation"
            )
        except Exception as e:
            # Deliberate best effort: regular entries are still produced even
            # when the sub_10 dataset is unavailable.
            print(f" Warning: Could not load sub_10 dataset: {e}")
            print(" Optimality entries will not be generated")
            include_optimality = False

    # Track selected models for sub_10 entries
    sub10_selected_models = {}  # {global_index: (selected_model, prompt)}
    regular_count = 0

    # Generate regular entries for all queries
    for entry in dataset:
        global_index = entry.get("global index")
        prompt = entry.get("prompt_formatted") or entry.get("prompt")
        # Skip entries with no usable index or prompt. The index is compared
        # explicitly so that a legitimate falsy index (e.g. 0) is not dropped.
        if global_index is None or global_index == "" or not prompt:
            continue

        # Use the router to get prediction (validation is handled by BaseRouter)
        selected_model = router.get_prediction(prompt)

        # Remember the choice for sub_10 queries (for optimality generation)
        if global_index in sub10_indices:
            sub10_selected_models[global_index] = (selected_model, prompt)

        predictions.append(
            {
                "global index": global_index,
                "prompt": prompt,
                "prediction": selected_model,
                "generated_result": None,
                "cost": None,
                "accuracy": None,
                "for_optimality": False,  # Regular entry
            }
        )
        regular_count += 1

    # Generate optimality entries for sub_10 queries
    if include_optimality and sub10_selected_models:
        print(
            f"\n Generating optimality entries for {len(sub10_selected_models)} sub_10 queries..."
        )
        optimality_count = 0
        for global_index, (selected_model, prompt) in sub10_selected_models.items():
            # Generate entries for all OTHER models in pool
            other_models = [m for m in model_pool if m != selected_model]
            for model in other_models:
                predictions.append(
                    {
                        "global index": global_index,
                        "prompt": prompt,
                        "prediction": model,  # Other model, not the one router selected
                        "generated_result": None,
                        "cost": None,
                        "accuracy": None,
                        "for_optimality": True,  # Flag for optimality calculation
                    }
                )
                optimality_count += 1
        print(f" Generated {optimality_count} optimality entries")
        # Report the regular entries actually emitted, not len(dataset):
        # skipped dataset entries would otherwise inflate the breakdown.
        print(
            f" Total entries: {len(predictions)} ({regular_count} regular + {optimality_count} optimality)"
        )
    return predictions
def save_predictions(
    predictions: List[Dict[str, Any]], router_name: str, split: str
) -> None:
    """
    Write predictions as pretty-printed UTF-8 JSON under ./router_inference/predictions/.

    The file stem is the router name. The ``robustness`` and ``gpqa`` splits
    append "-robustness" / "-gpqa". Every other split (including both
    ``sub_10`` and ``full``) writes to ``<router_name>.json``, so those runs
    overwrite each other. The output directory is created if missing.

    Args:
        predictions: List of prediction dictionaries
        router_name: Name of the router, used as the file stem
        split: Dataset split the predictions were generated for
    """
    if split in ("robustness", "gpqa"):
        stem = f"{router_name}-{split}"
    else:
        stem = router_name

    prediction_path = f"./router_inference/predictions/{stem}.json"
    os.makedirs(os.path.dirname(prediction_path), exist_ok=True)

    with open(prediction_path, "w", encoding="utf-8") as out_file:
        json.dump(predictions, out_file, ensure_ascii=False, indent=2)

    print(f"✓ Saved {len(predictions)} predictions to {prediction_path}")
def main():
    """
    Command-line entry point: generate and save a router's prediction file.

    Parses ``<router_name> <split> [--no-optimality]`` and switches the working
    directory to the project root (the parent of this script's directory).
    It then loads ``./router_inference/config/<router_name>.json`` and
    instantiates the class named by ``pipeline_params.router_cls_name``
    (default "ExampleRouter") from ``router_inference.router``. Finally it runs
    the router over the requested split and writes the result with
    save_predictions().

    Returns:
        0 on success; used as the process exit code.

    Raises:
        FileNotFoundError: If the router config or dataset file is missing.
        ValueError: If ``router_cls_name`` does not name a class in
            ``router_inference.router``.
    """
    parser = argparse.ArgumentParser(
        description="Generate prediction file using router specified in config"
    )
    parser.add_argument(
        "router_name",
        type=str,
        help="Name of the router (corresponds to config file)",
    )
    parser.add_argument(
        "split",
        type=str,
        choices=list(DATASET_PATHS.keys()),
        help="Dataset split: 'sub_10', 'full', 'robustness', or 'gpqa'",
    )
    parser.add_argument(
        "--no-optimality",
        action="store_true",
        help="Skip generating optimality entries (default: include optimality entries)",
    )
    args = parser.parse_args()

    # The config, dataset and prediction paths used below are all relative to
    # the project root, so switch there first.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.abspath(os.path.join(current_dir, "../"))
    os.chdir(base_dir)

    print(f"Generating predictions for router: {args.router_name}")
    print(f"Dataset split: {args.split}")
    print("=" * 80)

    # Load router config first to get router_cls_name
    print("\n[1] Loading router config...")
    config_path = f"./router_inference/config/{args.router_name}.json"
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    pipeline_params = config.get("pipeline_params", {})
    model_pool = pipeline_params.get("models", [])
    router_cls_name = pipeline_params.get("router_cls_name", "ExampleRouter")

    print(f"✓ Config loaded: {config_path}")
    print(f" Router class: {router_cls_name}")
    print(f" Model pool: {len(model_pool)} models")
    print(f" Models: {', '.join(model_pool)}")

    # Initialize router dynamically based on router_cls_name
    print("\n[2] Initializing router...")

    # Import the router module to access router classes
    import router_inference.router as router_module

    # Only accept classes: a matching non-class attribute (e.g. an imported
    # module or helper function) would otherwise fail later with an unclear
    # TypeError when called.
    router_cls = getattr(router_module, router_cls_name, None)
    if not isinstance(router_cls, type):
        available = [
            name
            for name in dir(router_module)
            if not name.startswith("_")
            and isinstance(getattr(router_module, name), type)
        ]
        raise ValueError(
            f"Router class '{router_cls_name}' not found in router_inference.router module. "
            f"Available routers: {', '.join(available)}"
        )

    router = router_cls(args.router_name)
    print(f"✓ Router initialized: {router.router_name}")
    print(f" Available models: {', '.join(router.models)}")

    # Load dataset
    print("\n[3] Loading dataset...")
    dataset = load_dataset(args.split)
    print(f"✓ Dataset loaded: {len(dataset)} entries")

    # Generate predictions
    print("\n[4] Generating predictions...")
    if args.no_optimality:
        optimality_reason = "--no-optimality flag set"
    elif args.split not in {"sub_10", "full"}:
        # Report the actual split (robustness, gpqa, ...), not a fixed name.
        optimality_reason = f"not supported for {args.split} split"
    else:
        optimality_reason = None

    include_optimality = optimality_reason is None
    if include_optimality:
        print(
            " Including optimality entries for automatic optimality score calculation"
        )
    else:
        print(f" Skipping optimality entries ({optimality_reason})")

    predictions = generate_predictions(
        dataset, router, model_pool, args.split, include_optimality
    )
    print(f"✓ Generated {len(predictions)} total entries")

    # Save predictions
    print("\n[5] Saving predictions...")
    save_predictions(predictions, args.router_name, args.split)

    print("\n" + "=" * 80)
    print("✓ Prediction file generation completed!")
    print("=" * 80)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())