forked from Mercidaiha/IRT-Router
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoracle_routing_analysis.py
More file actions
76 lines (58 loc) · 2.5 KB
/
oracle_routing_analysis.py
File metadata and controls
76 lines (58 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python3
"""
Oracle Routing Analysis
Analyzes the maximum possible accuracy of a router that always chooses
one of the correct models (if one exists) for each prompt.
For each unique original_id:
- If at least one model has performance=1.0, the oracle router is "correct"
- If no model has performance=1.0, the oracle router is "wrong"
"""
import pandas as pd
import sys
def analyze_oracle_routing(csv_path: str) -> dict:
    """Report the accuracy ceiling of an oracle router over a results CSV.

    The oracle router always picks a correct model whenever one exists, so a
    prompt (identified by ``original_id``) counts as solved iff at least one
    of its rows has ``performance == 1.0``.

    Args:
        csv_path: Path (or anything ``pandas.read_csv`` accepts) to a CSV
            containing at least the ``original_id``, ``llm`` and
            ``performance`` columns.

    Returns:
        A summary dict with keys ``total_prompts``, ``oracle_correct``,
        ``oracle_wrong``, ``oracle_accuracy`` (a percentage; 0.0 when there
        are no prompts) and ``correct_count_distribution`` (number of
        correct models -> number of prompts, in ascending key order). The
        same figures are printed to stdout.

    Raises:
        KeyError: if a required column is missing from the CSV.
    """
    print(f"Loading data from: {csv_path}")
    # Read the CSV file
    df = pd.read_csv(csv_path)
    print(f"Total rows: {len(df)}")
    print(f"Columns: {list(df.columns)}")

    # Fail with an explicit message (same exception type indexing would raise).
    required = ("original_id", "llm", "performance")
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"CSV is missing required column(s): {missing}")

    # Get unique original_ids
    unique_ids = df["original_id"].unique()
    print(f"Unique prompts (original_id): {len(unique_ids)}")
    # Count unique models (llm column)
    unique_models = df["llm"].unique()
    print(f"Unique models: {len(unique_models)}")
    print(f"Models: {list(unique_models)}")

    # Number of models scoring exactly 1.0 on each prompt, computed in a
    # single grouped pass rather than re-filtering the frame per prompt.
    # dropna=False keeps rows with a missing original_id together as one
    # prompt, so the total matches the unique-id count printed above.
    is_correct = (df["performance"] == 1.0).astype(int)
    correct_counts = is_correct.groupby(df["original_id"], dropna=False).sum()

    # A prompt is oracle-correct iff at least one model got it right.
    oracle_correct = int((correct_counts > 0).sum())
    total_prompts = len(correct_counts)
    oracle_wrong = total_prompts - oracle_correct
    # An empty CSV has no prompts; avoid ZeroDivisionError and report 0.0.
    oracle_accuracy = (oracle_correct / total_prompts) * 100 if total_prompts else 0.0

    print("\n" + "=" * 50)
    print("ORACLE ROUTING ANALYSIS RESULTS")
    print("=" * 50)
    print(f"Total unique prompts: {total_prompts}")
    print(f"Prompts with at least one correct model: {oracle_correct}")
    print(f"Prompts with no correct model: {oracle_wrong}")
    print(f"\nOracle Routing Accuracy: {oracle_accuracy:.2f}%")
    print("=" * 50)

    # Additional analysis: distribution of how many models are correct per prompt
    print("\nAdditional Analysis:")
    print("\nDistribution of correct models per prompt:")
    distribution = {
        int(count): int(freq)
        for count, freq in sorted(correct_counts.value_counts().items())
    }
    # Only reached with at least one prompt, so total_prompts > 0 here.
    for count, freq in distribution.items():
        print(f" {count} correct model(s): {freq} prompts ({freq/total_prompts*100:.1f}%)")

    return {
        "total_prompts": total_prompts,
        "oracle_correct": oracle_correct,
        "oracle_wrong": oracle_wrong,
        "oracle_accuracy": oracle_accuracy,
        "correct_count_distribution": distribution,
    }
if __name__ == "__main__":
    # Take the CSV location from the first command-line argument when one is
    # given; otherwise fall back to the bundled default dataset path.
    cli_args = sys.argv[1:]
    if cli_args:
        csv_path = cli_args[0]
    else:
        csv_path = "new_data/routerarena_irtrouter_format.csv"
    analyze_oracle_routing(csv_path)