forked from SWE-agent/SWE-agent
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcompile_predictions.py
More file actions
146 lines (111 loc) · 4.63 KB
/
compile_predictions.py
File metadata and controls
146 lines (111 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
"""
Script to create a {model}_preds_{base_folder_name}.json file in the current
working directory by collecting prediction dictionaries from .pred or .patch
files in trajectory folders (<base_folder>/<model>/traj/<instance_id>/).
Usage:
python compile_predictions.py <base_folder> <model>
Example:
python compile_predictions.py /path/to/sweagent_results/sweap_eval_1023 claude-45sonnet
"""
import argparse
import ast
import json
from pathlib import Path
def load_pred_file(pred_file_path, instance_id):
    """Load a prediction from a ``.pred`` or ``.patch`` file.

    Args:
        pred_file_path: Location of the file to read (``str`` or
            ``pathlib.Path``).
        instance_id: Identifier stored in the result when the file is a
            ``.patch`` file. It is not used for other files.

    Returns:
        For a ``.patch`` file, a dict with the keys ``instance_id`` and
        ``model_patch`` (the whitespace-stripped file content). For any
        other file, the value parsed from its content: JSON is tried first,
        then a Python literal. ``None`` if the file cannot be read or
        parsed; the error is printed in that case.
    """
    try:
        # Coerce inside the ``try`` so an unusable path argument is reported
        # like any other load failure instead of raising to the caller.
        path = Path(pred_file_path)
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read().strip()

        # A .patch file holds only the raw diff, so wrap it in a dictionary.
        if path.suffix == '.patch':
            return {
                'instance_id': instance_id,
                'model_patch': content
            }

        # For .pred files, try to parse as JSON first.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            # SECURITY: this fallback used to call eval(), which executes
            # arbitrary code found in the file. ast.literal_eval() accepts
            # only Python literals (dict/list/str/number/bool/None), which
            # covers a dict written with repr(); anything else is rejected
            # and reported as a load failure below.
            return ast.literal_eval(content)
    except Exception as e:
        # Best-effort loader: one unreadable file must not abort the run.
        print(f"Error loading {pred_file_path}: {e}")
        return None
def collect_predictions(base_folder, model):
    """Collect all prediction dictionaries from trajectory folders.

    Scans ``<base_folder>/<model>/traj``. Each sub-directory is treated as
    one instance whose id is the directory name. Within an instance folder
    a ``*.pred`` file is preferred; ``*.patch`` files are considered only
    when no ``.pred`` file exists.

    Args:
        base_folder: Base directory path.
        model: Model name (sub-directory of ``base_folder``).

    Returns:
        List of the values returned by ``load_pred_file`` for every instance
        that loaded successfully, ordered by instance folder path. The list
        is empty when the trajectory folder is missing.
    """
    predictions = []

    # Build path to trajectory folder.
    traj_folder = Path(base_folder) / model / "traj"

    # is_dir() rather than exists(): a regular file at this path would pass
    # an exists() check and then make iterdir() raise NotADirectoryError.
    if not traj_folder.is_dir():
        print(f"Trajectory folder not found: {traj_folder}")
        return predictions

    print(f"Scanning trajectory folder: {traj_folder}")

    # iterdir() yields entries in arbitrary order; sort so the resulting
    # list is reproducible between runs and machines.
    for instance_folder in sorted(traj_folder.iterdir()):
        if not instance_folder.is_dir():
            continue

        instance_id = instance_folder.name

        # Look for .pred files first, then fall back to .patch files.
        # Sorted so that "the first file" is a deterministic choice.
        pred_files = sorted(instance_folder.glob("*.pred"))
        if not pred_files:
            pred_files = sorted(instance_folder.glob("*.patch"))

        if not pred_files:
            print(f"No .pred or .patch file found in: {instance_folder}")
            continue

        # Use the first file found.
        pred_file = pred_files[0]
        print(f"Processing: {instance_id} (using {pred_file.suffix})")

        # Load the prediction dictionary.
        pred_dict = load_pred_file(pred_file, instance_id)
        if pred_dict is not None:
            predictions.append(pred_dict)
        else:
            print(f"Failed to load prediction from: {pred_file}")

    return predictions
def main():
    """Parse the command line, gather predictions and write them as JSON.

    The output file is created in the current working directory and is
    named ``<model>_preds_<base folder name>.json``.

    Returns:
        0 on success; 1 if the base folder does not exist, no predictions
        were collected, or the output file could not be written.
    """
    cli = argparse.ArgumentParser(
        description='Collect prediction dictionaries from trajectory folders and save to JSON.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
%(prog)s /path/to/sweagent_results/sweap_eval_1023 claude-45sonnet
""",
    )
    cli.add_argument(
        'base_folder',
        type=str,
        help='Base directory containing model results',
    )
    cli.add_argument(
        'model',
        type=str,
        help='Model name (subdirectory name within base_folder)',
    )
    parsed = cli.parse_args()

    root = Path(parsed.base_folder)
    model_name = parsed.model

    print(f"Base folder: {root}")
    print(f"Model: {model_name}")

    # Guard: nothing to scan without the base folder.
    if not root.exists():
        print(f"Error: Base folder does not exist: {root}")
        return 1

    collected = collect_predictions(root, model_name)
    total = len(collected)
    print(f"\nCollected {total} predictions")

    # Guard: do not write an empty output file.
    if not collected:
        print("Warning: No predictions were collected")
        return 1

    # The base folder name is part of the output filename.
    target = f"{model_name}_preds_{root.name}.json"

    try:
        with open(target, 'w', encoding='utf-8') as out:
            json.dump(collected, out, indent=2, ensure_ascii=False)
        print(f"\nSuccessfully created: {target}")
        print(f"Contains {total} prediction dictionaries")
        return 0
    except Exception as err:
        print(f"Error saving to {target}: {err}")
        return 1
if __name__ == "__main__":
    # The builtin exit() is installed by the `site` module for interactive
    # use and is absent under `python -S` or in frozen builds. Raising
    # SystemExit propagates main()'s return code as the process exit status
    # without depending on it.
    raise SystemExit(main())