forked from Mercidaiha/IRT-Router
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreformat_csv.py
More file actions
147 lines (115 loc) · 4.97 KB
/
reformat_csv.py
File metadata and controls
147 lines (115 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#!/usr/bin/env python3
"""
Reformat a CSV file from the multi-LLM-per-row format (like sap_facility_custom_routing_merged.csv)
to the one-LLM-per-row format (like router_arena_data/test2.csv).
Source format:
- Columns: prompt_id, prompt_id, prompt, {model}/response, {model}/score, ...
- One row per prompt with multiple model response/score columns
Target format:
- Columns: id, original_id, question, ground_truth, completion, input_tokens, output_tokens, cost, performance, task, llm
- One row per model response, with original_id linking rows from the same prompt
"""
import argparse
import csv
import re
import sys
from pathlib import Path
# Raise the csv module's per-field size cap so rows containing very large
# fields (e.g. long model responses) can still be parsed.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    # The limit is stored in a C long; where that type is 32-bit (e.g. 64-bit
    # Windows) sys.maxsize does not fit, so fall back to the largest value
    # a 32-bit long can hold.
    csv.field_size_limit(2**31 - 1)
def extract_model_columns(headers: list[str]) -> list[tuple[str, str, str]]:
    """
    Find the models that have both a response column and a score column.

    A header of the form "{model}/response" is paired with "{model}/score".
    Response columns without a matching score column are ignored.

    Returns list of tuples: (model_name, response_col, score_col), where
    response_col and score_col are column names (not positions). Entries
    follow header order, and a response header that is repeated in `headers`
    produces a single entry.
    """
    response_pattern = re.compile(r'^(.+)/response$')
    # Build the lookup set once so each score-column check is O(1).
    known_headers = set(headers)
    seen_response_cols = set()
    models = []
    for header in headers:
        match = response_pattern.match(header)
        if not match:
            continue
        # A repeated header would otherwise yield the same model twice and
        # duplicate every output row for it.
        if header in seen_response_cols:
            continue
        model_name = match.group(1)
        score_col = f"{model_name}/score"
        if score_col in known_headers:
            seen_response_cols.add(header)
            models.append((model_name, header, score_col))
    return models
def reformat_csv(input_file: str, output_file: str) -> None:
    """
    Reformat CSV from multi-LLM-per-row to one-LLM-per-row format.

    Reads `input_file`, which needs a 'prompt' (or 'question') column and at
    least one "{model}/response" + "{model}/score" column pair, and writes
    `output_file` with one row per (input row, model) combination.

    In the output, 'id' numbers the output rows and 'original_id' numbers the
    input rows, both starting at 0. Columns the source format does not carry
    (ground_truth, input_tokens, output_tokens, cost, task) are filled with
    the literal string 'none'.

    Prints an error to stderr and exits with status 1 (SystemExit) when the
    input file is missing, the output path is the same file as the input,
    the header row cannot be read, no prompt column exists, or no model
    response/score pairs are found.
    """
    input_path = Path(input_file)
    output_path = Path(output_file)

    if not input_path.exists():
        print(f"Error: Input file '{input_file}' not found.", file=sys.stderr)
        sys.exit(1)

    # The output is opened in 'w' mode, which truncates it immediately; if it
    # were the input file, the data would be destroyed before being read.
    if output_path.exists() and input_path.samefile(output_path):
        print("Error: Input and output must be different files.", file=sys.stderr)
        sys.exit(1)

    # 'utf-8-sig' decodes plain UTF-8 as well, but also drops a leading BOM
    # that would otherwise become part of the first header name.
    with open(input_path, 'r', newline='', encoding='utf-8-sig') as infile:
        reader = csv.DictReader(infile)
        headers = reader.fieldnames
        if not headers:
            print("Error: Could not read headers from input file.", file=sys.stderr)
            sys.exit(1)

        # Find the prompt column (could be 'prompt' or 'question')
        prompt_col = next(
            (col for col in ('prompt', 'question') if col in headers),
            None,
        )
        if not prompt_col:
            print(
                "Error: Could not find 'prompt' or 'question' column in input file.",
                file=sys.stderr,
            )
            sys.exit(1)

        # Extract model columns
        models = extract_model_columns(list(headers))
        if not models:
            print(
                "Error: Could not find any model response/score column pairs.",
                file=sys.stderr,
            )
            print(f"Headers found: {headers}", file=sys.stderr)
            sys.exit(1)

        print(f"Found {len(models)} models:")
        for model_name, _, _ in models:
            print(f"  - {model_name}")

        # Output columns matching target format
        output_headers = [
            'id', 'original_id', 'question', 'ground_truth', 'completion',
            'input_tokens', 'output_tokens', 'cost', 'performance', 'task', 'llm',
        ]

        with open(output_path, 'w', newline='', encoding='utf-8') as outfile:
            writer = csv.DictWriter(outfile, fieldnames=output_headers)
            writer.writeheader()

            row_id = 0
            original_id = 0
            for row in reader:
                prompt = row.get(prompt_col, '')
                # Create one output row per model
                for model_name, response_col, score_col in models:
                    writer.writerow({
                        'id': row_id,
                        'original_id': original_id,
                        'question': prompt,
                        'ground_truth': 'none',
                        'completion': row.get(response_col, ''),
                        'input_tokens': 'none',
                        'output_tokens': 'none',
                        'cost': 'none',
                        'performance': row.get(score_col, ''),
                        'task': 'none',
                        'llm': model_name,
                    })
                    row_id += 1
                original_id += 1

    print(f"\nReformatted {original_id} prompts into {row_id} rows.")
    print(f"Output written to: {output_path}")
def main(argv=None):
    """
    Command-line entry point: parse the arguments and run reformat_csv.

    argv is an optional list of argument strings; when it is None, argparse
    reads the arguments from sys.argv[1:].

    Both --input-file and --output-file are required. If either is missing,
    argparse prints a usage message and exits with status 2 instead of
    passing None on to reformat_csv.
    """
    parser = argparse.ArgumentParser(
        description='Reformat CSV from multi-LLM-per-row to one-LLM-per-row format.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  python reformat_csv.py --input-file input.csv --output-file output.csv
  python reformat_csv.py --input-file sap_facility_data/sap_facility_custom_routing_merged.csv --output-file reformatted_output.csv
'''
    )
    parser.add_argument(
        '--input-file',
        required=True,
        help='Input CSV file (multi-LLM-per-row format)',
    )
    parser.add_argument(
        '--output-file',
        required=True,
        help='Output CSV file (one-LLM-per-row format)',
    )
    args = parser.parse_args(argv)
    reformat_csv(args.input_file, args.output_file)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()