
Commit 0d96a23

Adding Model Evaluation Statistics

1 parent 93378a7 commit 0d96a23
2 files changed: 330 additions & 0 deletions


XPointMLTest.py

Lines changed: 64 additions & 0 deletions
@@ -28,6 +28,9 @@
 # Import benchmark module
 from benchmark import TrainingBenchmark
 
+# Import evaluation metrics module
+from eval_metrics import ModelEvaluator, evaluate_model_on_dataset
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -768,6 +771,8 @@ def parseCommandLineArgs():
                         help='enable performance benchmarking (tracks timing, throughput, GPU memory)')
     parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json',
                         help='path to save benchmark results JSON file (default: ./benchmark_results.json)')
+    parser.add_argument('--eval-output', type=Path, default='./evaluation_metrics.json',
+                        help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)')
 
     # CI TEST: Add smoke test flag
     parser.add_argument('--smoke-test', action='store_true',
@@ -1172,6 +1177,65 @@ def main():
     print("Loading best model for evaluation...")
     model.load_state_dict(torch.load(best_model_path, weights_only=True))
 
+    # ==================== NEW EVALUATION CODE ====================
+    # Evaluate model performance
+    if not args.smoke_test:
+        # print("\n" + "="*70)
+        # print("RUNNING MODEL EVALUATION")
+        # print("="*70)
+
+        # Evaluate on validation set
+        print("\nEvaluating on validation set...")
+        val_evaluator = evaluate_model_on_dataset(
+            model,
+            val_dataset,  # Use original dataset, not patch dataset
+            device,
+            use_amp=use_amp,
+            amp_dtype=amp_dtype,
+            threshold=0.5
+        )
+
+        # Print and save validation metrics
+        val_evaluator.print_summary()
+        val_evaluator.save_json(args.eval_output)
+
+        # Evaluate on training set
+        print("\nEvaluating on training set...")
+        train_evaluator = evaluate_model_on_dataset(
+            model,
+            train_dataset,
+            device,
+            use_amp=use_amp,
+            amp_dtype=amp_dtype,
+            threshold=0.5
+        )
+
+        # Print and save training metrics
+        train_evaluator.print_summary()
+        train_eval_path = args.eval_output.parent / f"train_{args.eval_output.name}"
+        train_evaluator.save_json(train_eval_path)
+
+        # Compare training vs validation to check for overfitting
+        train_global = train_evaluator.get_global_metrics()
+        val_global = val_evaluator.get_global_metrics()
+
+        print("\n" + "="*70)
+        print("OVERFITTING CHECK")
+        print("="*70)
+        print(f"Training F1:   {train_global['f1_score']:.4f}")
+        print(f"Validation F1: {val_global['f1_score']:.4f}")
+        print(f"Difference:    {abs(train_global['f1_score'] - val_global['f1_score']):.4f}")
+
+        if train_global['f1_score'] - val_global['f1_score'] > 0.05:
+            print("⚠ Warning: Possible overfitting detected (train F1 >> val F1)")
+        elif val_global['f1_score'] - train_global['f1_score'] > 0.05:
+            print("⚠ Warning: Unusual pattern (val F1 >> train F1)")
+        else:
+            print("✓ Model generalizes well to validation set")
+        print("="*70 + "\n")
+
+    # ==================== END NEW EVALUATION CODE ====================
+
     # (D) Plotting after training
     model.eval()  # switch to inference mode
     outDir = "plots"
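Editor's note: the evaluation block above writes two JSON files (validation metrics to --eval-output, training metrics with a train_ prefix). Below is a minimal offline sketch, not part of this commit, showing how those files could be compared after a run. It assumes the default --eval-output path and the key names defined in eval_metrics.save_json(); adjust the paths if the flag is overridden.

import json
from pathlib import Path

# Paths assume the defaults added in parseCommandLineArgs() above.
val_path = Path("./evaluation_metrics.json")
train_path = val_path.parent / f"train_{val_path.name}"

with open(val_path) as f:
    val = json.load(f)["global_metrics"]
with open(train_path) as f:
    train = json.load(f)["global_metrics"]

# Same heuristic as the in-script check: flag a train/val F1 gap above 0.05.
gap = train["f1_score"] - val["f1_score"]
print(f"train F1={train['f1_score']:.4f}  val F1={val['f1_score']:.4f}  gap={gap:+.4f}")
if gap > 0.05:
    print("possible overfitting (train F1 >> val F1)")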

eval_metrics.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
"""
Model evaluation metrics for X-point detection.

This module provides functions to compute detailed performance metrics
for the X-point detection model, including per-frame and global statistics.
"""

import numpy as np
import json
from pathlib import Path
import torch
from torch.amp import autocast


class ModelEvaluator:
    """
    Evaluates model performance on X-point detection task.

    Computes metrics including:
    - True Positives (TP): X-point pixels correctly identified
    - False Positives (FP): Background pixels incorrectly labeled as X-points
    - False Negatives (FN): X-point pixels that were missed
    - True Negatives (TN): Background pixels correctly identified

    Metrics calculated:
    - Accuracy:  (TP + TN) / (TP + TN + FP + FN)
    - Precision: TP / (TP + FP)
    - Recall:    TP / (TP + FN)
    - F1 Score:  2 * (Precision * Recall) / (Precision + Recall)
    - IoU:       TP / (TP + FP + FN)
    """

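    # Editor's note: worked example of the formulas above (illustrative only,
    # not part of the committed file). For a frame of 1,000 pixels with
    # TP=8, FP=2, FN=4, TN=986:
    #   accuracy  = (8 + 986) / 1000                  = 0.9940
    #   precision = 8 / (8 + 2)                       = 0.8000
    #   recall    = 8 / (8 + 4)                       ≈ 0.6667
    #   f1        = 2 * 0.8 * 0.6667 / (0.8 + 0.6667) ≈ 0.7273
    #   iou       = 8 / (8 + 2 + 4)                   ≈ 0.5714
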
    def __init__(self, threshold=0.5):
        """
        Initialize evaluator.

        Parameters:
            threshold: float - Probability threshold for binary classification (default: 0.5)
        """
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Reset all accumulated metrics."""
        self.global_tp = 0
        self.global_fp = 0
        self.global_fn = 0
        self.global_tn = 0
        self.frame_metrics = []

    def compute_frame_metrics(self, pred_probs, ground_truth):
        """
        Compute metrics for a single frame.

        Parameters:
            pred_probs: np.ndarray - Predicted probabilities, shape [H, W]
            ground_truth: np.ndarray - Ground truth binary mask, shape [H, W]

        Returns:
            dict - Dictionary containing TP, FP, FN, TN and derived metrics
        """
        # Binarize predictions
        pred_binary = (pred_probs > self.threshold).astype(np.float32)
        gt_binary = (ground_truth > 0.5).astype(np.float32)

        # Compute confusion matrix elements
        tp = np.sum((pred_binary == 1) & (gt_binary == 1))
        fp = np.sum((pred_binary == 1) & (gt_binary == 0))
        fn = np.sum((pred_binary == 0) & (gt_binary == 1))
        tn = np.sum((pred_binary == 0) & (gt_binary == 0))

        # Compute derived metrics
        total = tp + fp + fn + tn
        accuracy = (tp + tn) / total if total > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0

        return {
            'tp': int(tp),
            'fp': int(fp),
            'fn': int(fn),
            'tn': int(tn),
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'iou': float(iou)
        }

    def add_frame(self, pred_probs, ground_truth, frame_id=None):
        """
        Add a frame's results to the evaluation.

        Parameters:
            pred_probs: np.ndarray - Predicted probabilities
            ground_truth: np.ndarray - Ground truth binary mask
            frame_id: int or str - Optional frame identifier
        """
        metrics = self.compute_frame_metrics(pred_probs, ground_truth)

        # Add to global counts
        self.global_tp += metrics['tp']
        self.global_fp += metrics['fp']
        self.global_fn += metrics['fn']
        self.global_tn += metrics['tn']

        # Store frame metrics
        if frame_id is not None:
            metrics['frame_id'] = frame_id
        self.frame_metrics.append(metrics)

    def get_global_metrics(self):
        """
        Compute global metrics across all frames.

        Returns:
            dict - Global metrics computed from accumulated confusion matrix
        """
        total = self.global_tp + self.global_fp + self.global_fn + self.global_tn

        metrics = {
            'global_tp': int(self.global_tp),
            'global_fp': int(self.global_fp),
            'global_fn': int(self.global_fn),
            'global_tn': int(self.global_tn),
            'total_pixels': int(total),
            'accuracy': (self.global_tp + self.global_tn) / total if total > 0 else 0.0,
            'precision': self.global_tp / (self.global_tp + self.global_fp)
                if (self.global_tp + self.global_fp) > 0 else 0.0,
            'recall': self.global_tp / (self.global_tp + self.global_fn)
                if (self.global_tp + self.global_fn) > 0 else 0.0,
            'iou': self.global_tp / (self.global_tp + self.global_fp + self.global_fn)
                if (self.global_tp + self.global_fp + self.global_fn) > 0 else 0.0,
        }

        # Compute F1 from global precision and recall
        if (metrics['precision'] + metrics['recall']) > 0:
            metrics['f1_score'] = 2 * metrics['precision'] * metrics['recall'] / \
                (metrics['precision'] + metrics['recall'])
        else:
            metrics['f1_score'] = 0.0

        return metrics

    def get_frame_statistics(self):
        """
        Compute statistics across all frames.

        Returns:
            dict - Mean and standard deviation for each metric across frames
        """
        if not self.frame_metrics:
            return {}

        metrics_arrays = {
            key: np.array([frame[key] for frame in self.frame_metrics])
            for key in ['accuracy', 'precision', 'recall', 'f1_score', 'iou']
        }

        stats = {}
        for metric_name, values in metrics_arrays.items():
            stats[f'{metric_name}_mean'] = float(np.mean(values))
            stats[f'{metric_name}_std'] = float(np.std(values))
            stats[f'{metric_name}_min'] = float(np.min(values))
            stats[f'{metric_name}_max'] = float(np.max(values))

        return stats

    def print_summary(self):
        """Print comprehensive evaluation summary."""
        print("\n" + "="*70)
        print("MODEL EVALUATION METRICS")
        print("="*70)

        global_metrics = self.get_global_metrics()

        print("\nGlobal Metrics (across all frames):")
        print(f"  Total pixels evaluated: {global_metrics['total_pixels']:,}")
        print(f"  True Positives (TP):  {global_metrics['global_tp']:,}")
        print(f"  False Positives (FP): {global_metrics['global_fp']:,}")
        print(f"  False Negatives (FN): {global_metrics['global_fn']:,}")
        print(f"  True Negatives (TN):  {global_metrics['global_tn']:,}")
        print(f"\n  Accuracy:  {global_metrics['accuracy']:.4f}")
        print(f"  Precision: {global_metrics['precision']:.4f}")
        print(f"  Recall:    {global_metrics['recall']:.4f}")
        print(f"  F1 Score:  {global_metrics['f1_score']:.4f}")
        print(f"  IoU:       {global_metrics['iou']:.4f}")

        if self.frame_metrics:
            print(f"\nPer-Frame Statistics ({len(self.frame_metrics)} frames):")
            stats = self.get_frame_statistics()

            for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'iou']:
                mean = stats[f'{metric}_mean']
                std = stats[f'{metric}_std']
                min_val = stats[f'{metric}_min']
                max_val = stats[f'{metric}_max']
                print(f"  {metric.replace('_', ' ').title():20s} "
                      f"mean={mean:.4f} ±{std:.4f} "
                      f"[{min_val:.4f}, {max_val:.4f}]")

        print("="*70 + "\n")

    def save_json(self, output_file):
        """
        Save evaluation results to JSON file.

        Parameters:
            output_file: Path or str - File path to save evaluation data
        """
        evaluation_data = {
            'global_metrics': self.get_global_metrics(),
            'frame_statistics': self.get_frame_statistics(),
            'per_frame_metrics': self.frame_metrics,
            'threshold': self.threshold,
            'num_frames': len(self.frame_metrics)
        }

        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(evaluation_data, f, indent=2)

        print(f"Evaluation metrics saved to: {output_path}")


def evaluate_model_on_dataset(model, dataset, device, use_amp=False,
                              amp_dtype=torch.float16, threshold=0.5):
    """
    Evaluate model on entire dataset and return metrics.

    Parameters:
        model: nn.Module - The trained model
        dataset: Dataset - Dataset to evaluate on (XPointDataset, not patch dataset)
        device: torch.device - Device to run evaluation on
        use_amp: bool - Whether to use automatic mixed precision
        amp_dtype: torch.dtype - Data type for mixed precision
        threshold: float - Threshold for binary classification

    Returns:
        ModelEvaluator - Evaluator object with computed metrics
    """
    model.eval()
    evaluator = ModelEvaluator(threshold=threshold)

    with torch.no_grad():
        for item in dataset:
            fnum = item["fnum"]
            all_torch = item["all"].unsqueeze(0).to(device)
            mask_gt = item["mask"][0].cpu().numpy()  # Remove channel dimension

            # Get prediction
            with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
                pred_mask = model(all_torch)
                pred_prob = torch.sigmoid(pred_mask)

            # Convert to numpy (handle BFloat16)
            pred_prob_np = pred_prob[0, 0].float().cpu().numpy()

            # Add to evaluator
            evaluator.add_frame(pred_prob_np, mask_gt, frame_id=fnum)

    return evaluator
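
Editor's note: a minimal standalone sketch of the ModelEvaluator API added above, using synthetic arrays rather than real model output (illustrative only, not part of this commit; the output file name is arbitrary).

import numpy as np
from eval_metrics import ModelEvaluator

evaluator = ModelEvaluator(threshold=0.5)
rng = np.random.default_rng(0)
for frame_id in range(3):
    # Sparse synthetic ground-truth mask and a near-perfect "prediction" for it.
    gt = (rng.random((64, 64)) > 0.99).astype(np.float32)
    pred = np.clip(gt + 0.1 * rng.random((64, 64)), 0.0, 1.0)
    evaluator.add_frame(pred, gt, frame_id=frame_id)

evaluator.print_summary()
evaluator.save_json("demo_evaluation_metrics.json")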
