1+ """
2+ Model evaluation metrics for X-point detection.
3+
4+ This module provides functions to compute detailed performance metrics
5+ for the X-point detection model, including per-frame and global statistics.
6+ """
7+
8+ import numpy as np
9+ import json
10+ from pathlib import Path
11+ import torch
12+ from torch .amp import autocast


class ModelEvaluator:
    """
    Evaluates model performance on the X-point detection task.

    Accumulates per-pixel confusion-matrix counts:
    - True Positives (TP): X-point pixels correctly identified
    - False Positives (FP): background pixels incorrectly labeled as X-points
    - False Negatives (FN): X-point pixels that were missed
    - True Negatives (TN): background pixels correctly identified

    Metrics calculated:
    - Accuracy: (TP + TN) / (TP + TN + FP + FN)
    - Precision: TP / (TP + FP)
    - Recall: TP / (TP + FN)
    - F1 Score: 2 * (Precision * Recall) / (Precision + Recall)
    - IoU: TP / (TP + FP + FN)
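
    Example (a minimal doctest-style sketch; the 2x2 frame values below are
    illustrative, not taken from any real dataset):

    >>> import numpy as np
    >>> evaluator = ModelEvaluator(threshold=0.5)
    >>> probs = np.array([[0.9, 0.2], [0.7, 0.1]])
    >>> target = np.array([[1.0, 0.0], [0.0, 1.0]])
    >>> m = evaluator.compute_frame_metrics(probs, target)
    >>> (m['tp'], m['fp'], m['fn'], m['tn'])
    (1, 1, 1, 1)
    >>> round(m['precision'], 2), round(m['recall'], 2), round(m['iou'], 2)
    (0.5, 0.5, 0.33)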
31+ """
32+
33+ def __init__ (self , threshold = 0.5 ):
34+ """
35+ Initialize evaluator.
36+
37+ Parameters:
38+ threshold: float - Probability threshold for binary classification (default: 0.5)
39+ """
40+ self .threshold = threshold
41+ self .reset ()
42+
43+ def reset (self ):
44+ """Reset all accumulated metrics."""
45+ self .global_tp = 0
46+ self .global_fp = 0
47+ self .global_fn = 0
48+ self .global_tn = 0
49+ self .frame_metrics = []

    def compute_frame_metrics(self, pred_probs, ground_truth):
        """
        Compute metrics for a single frame.

        Parameters:
            pred_probs: np.ndarray - Predicted probabilities, shape [H, W]
            ground_truth: np.ndarray - Ground truth binary mask, shape [H, W]

        Returns:
            dict - Dictionary containing TP, FP, FN, TN and derived metrics
        """
        # Binarize predictions
        pred_binary = (pred_probs > self.threshold).astype(np.float32)
        gt_binary = (ground_truth > 0.5).astype(np.float32)

        # Compute confusion matrix elements
        tp = np.sum((pred_binary == 1) & (gt_binary == 1))
        fp = np.sum((pred_binary == 1) & (gt_binary == 0))
        fn = np.sum((pred_binary == 0) & (gt_binary == 1))
        tn = np.sum((pred_binary == 0) & (gt_binary == 0))

        # Compute derived metrics
        total = tp + fp + fn + tn
        accuracy = (tp + tn) / total if total > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0

        return {
            'tp': int(tp),
            'fp': int(fp),
            'fn': int(fn),
            'tn': int(tn),
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'iou': float(iou)
        }

    def add_frame(self, pred_probs, ground_truth, frame_id=None):
        """
        Add a frame's results to the evaluation.

        Parameters:
            pred_probs: np.ndarray - Predicted probabilities
            ground_truth: np.ndarray - Ground truth binary mask
            frame_id: int or str - Optional frame identifier
        """
        metrics = self.compute_frame_metrics(pred_probs, ground_truth)

        # Add to global counts
        self.global_tp += metrics['tp']
        self.global_fp += metrics['fp']
        self.global_fn += metrics['fn']
        self.global_tn += metrics['tn']

        # Store frame metrics
        if frame_id is not None:
            metrics['frame_id'] = frame_id
        self.frame_metrics.append(metrics)

    def get_global_metrics(self):
        """
        Compute global (micro-averaged) metrics across all frames.

        Returns:
            dict - Global metrics computed from the accumulated confusion matrix
        """
        total = self.global_tp + self.global_fp + self.global_fn + self.global_tn

        metrics = {
            'global_tp': int(self.global_tp),
            'global_fp': int(self.global_fp),
            'global_fn': int(self.global_fn),
            'global_tn': int(self.global_tn),
            'total_pixels': int(total),
            'accuracy': (self.global_tp + self.global_tn) / total if total > 0 else 0.0,
            'precision': self.global_tp / (self.global_tp + self.global_fp)
                         if (self.global_tp + self.global_fp) > 0 else 0.0,
            'recall': self.global_tp / (self.global_tp + self.global_fn)
                      if (self.global_tp + self.global_fn) > 0 else 0.0,
            'iou': self.global_tp / (self.global_tp + self.global_fp + self.global_fn)
                   if (self.global_tp + self.global_fp + self.global_fn) > 0 else 0.0,
        }

        # Compute F1 from global precision and recall
        if (metrics['precision'] + metrics['recall']) > 0:
            metrics['f1_score'] = (2 * metrics['precision'] * metrics['recall']
                                   / (metrics['precision'] + metrics['recall']))
        else:
            metrics['f1_score'] = 0.0

        return metrics

    def get_frame_statistics(self):
        """
        Compute statistics across all frames.

        Returns:
            dict - Mean, standard deviation, min and max for each metric across frames
        """
        if not self.frame_metrics:
            return {}

        metrics_arrays = {
            key: np.array([frame[key] for frame in self.frame_metrics])
            for key in ['accuracy', 'precision', 'recall', 'f1_score', 'iou']
        }

        stats = {}
        for metric_name, values in metrics_arrays.items():
            stats[f'{metric_name}_mean'] = float(np.mean(values))
            stats[f'{metric_name}_std'] = float(np.std(values))
            stats[f'{metric_name}_min'] = float(np.min(values))
            stats[f'{metric_name}_max'] = float(np.max(values))

        return stats

    def print_summary(self):
        """Print comprehensive evaluation summary."""
        print("\n" + "=" * 70)
        print("MODEL EVALUATION METRICS")
        print("=" * 70)

        global_metrics = self.get_global_metrics()

        print("\nGlobal Metrics (across all frames):")
        print(f"  Total pixels evaluated: {global_metrics['total_pixels']:,}")
        print(f"  True Positives (TP):  {global_metrics['global_tp']:,}")
        print(f"  False Positives (FP): {global_metrics['global_fp']:,}")
        print(f"  False Negatives (FN): {global_metrics['global_fn']:,}")
        print(f"  True Negatives (TN):  {global_metrics['global_tn']:,}")
        print(f"\n  Accuracy:  {global_metrics['accuracy']:.4f}")
        print(f"  Precision: {global_metrics['precision']:.4f}")
        print(f"  Recall:    {global_metrics['recall']:.4f}")
        print(f"  F1 Score:  {global_metrics['f1_score']:.4f}")
        print(f"  IoU:       {global_metrics['iou']:.4f}")

        if self.frame_metrics:
            print(f"\nPer-Frame Statistics ({len(self.frame_metrics)} frames):")
            stats = self.get_frame_statistics()

            for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'iou']:
                mean = stats[f'{metric}_mean']
                std = stats[f'{metric}_std']
                min_val = stats[f'{metric}_min']
                max_val = stats[f'{metric}_max']
                print(f"  {metric.replace('_', ' ').title():20s} "
                      f"mean={mean:.4f} ± {std:.4f} "
                      f"[{min_val:.4f}, {max_val:.4f}]")

        print("=" * 70 + "\n")

    def save_json(self, output_file):
        """
        Save evaluation results to a JSON file.

        Parameters:
            output_file: Path or str - File path to save evaluation data
        """
        evaluation_data = {
            'global_metrics': self.get_global_metrics(),
            'frame_statistics': self.get_frame_statistics(),
            'per_frame_metrics': self.frame_metrics,
            'threshold': self.threshold,
            'num_frames': len(self.frame_metrics)
        }

        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(evaluation_data, f, indent=2)

        print(f"Evaluation metrics saved to: {output_path}")


def evaluate_model_on_dataset(model, dataset, device, use_amp=False,
                              amp_dtype=torch.float16, threshold=0.5):
    """
    Evaluate the model on an entire dataset and return the accumulated metrics.

    Parameters:
        model: nn.Module - The trained model
        dataset: Dataset - Dataset to evaluate on (XPointDataset, not patch dataset)
        device: torch.device - Device to run evaluation on
        use_amp: bool - Whether to use automatic mixed precision
        amp_dtype: torch.dtype - Data type for mixed precision
        threshold: float - Threshold for binary classification

    Returns:
        ModelEvaluator - Evaluator object with computed metrics
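
    Example (illustrative only; model and val_dataset stand in for a trained
    network and an XPointDataset created elsewhere in the project):

        evaluator = evaluate_model_on_dataset(model, val_dataset,
                                              torch.device('cuda'),
                                              use_amp=True, threshold=0.5)
        evaluator.print_summary()
        evaluator.save_json('eval/xpoint_metrics.json')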
245+ """
246+ model .eval ()
247+ evaluator = ModelEvaluator (threshold = threshold )
248+
249+ with torch .no_grad ():
250+ for item in dataset :
251+ fnum = item ["fnum" ]
252+ all_torch = item ["all" ].unsqueeze (0 ).to (device )
253+ mask_gt = item ["mask" ][0 ].cpu ().numpy () # Remove channel dimension
254+
255+ # Get prediction
256+ with autocast (device_type = 'cuda' , dtype = amp_dtype , enabled = use_amp ):
257+ pred_mask = model (all_torch )
258+ pred_prob = torch .sigmoid (pred_mask )
259+
260+ # Convert to numpy (handle BFloat16)
261+ pred_prob_np = pred_prob [0 , 0 ].float ().cpu ().numpy ()
262+
263+ # Add to evaluator
264+ evaluator .add_frame (pred_prob_np , mask_gt , frame_id = fnum )
265+
266+ return evaluator
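

if __name__ == "__main__":
    # Minimal self-check sketch (an illustrative addition, not part of the
    # original pipeline): exercises ModelEvaluator on synthetic random frames
    # so the metric bookkeeping can be checked without a trained model or
    # dataset. Frame size, frame count, and the output path are arbitrary.
    rng = np.random.default_rng(0)
    demo_evaluator = ModelEvaluator(threshold=0.5)
    for demo_frame_id in range(3):
        fake_probs = rng.random((64, 64)).astype(np.float32)
        fake_mask = (rng.random((64, 64)) > 0.95).astype(np.float32)
        demo_evaluator.add_frame(fake_probs, fake_mask, frame_id=demo_frame_id)
    demo_evaluator.print_summary()
    demo_evaluator.save_json("eval_metrics/synthetic_eval.json")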