-
Notifications
You must be signed in to change notification settings - Fork 771
Expand file tree
/
Copy pathtest_binary_metrics.py
More file actions
120 lines (104 loc) · 4.24 KB
/
test_binary_metrics.py
File metadata and controls
120 lines (104 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import unittest
import numpy as np
from pyhealth.metrics import binary_metrics_fn
class TestBinaryMetrics(unittest.TestCase):
"""Test cases for binary classification metrics."""
def setUp(self):
"""Set up synthetic binary classification data."""
np.random.seed(42)
self.y_true = np.array([0, 0, 1, 1, 0, 1])
self.y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])
def test_default_metrics(self):
"""Test that default metrics (pr_auc, roc_auc, f1) are returned."""
result = binary_metrics_fn(self.y_true, self.y_prob)
self.assertIn("pr_auc", result)
self.assertIn("roc_auc", result)
self.assertIn("f1", result)
self.assertEqual(len(result), 3)
def test_accuracy(self):
"""Test accuracy metric with known values."""
result = binary_metrics_fn(
self.y_true, self.y_prob, metrics=["accuracy"],
)
self.assertIn("accuracy", result)
self.assertIsInstance(result["accuracy"], float)
self.assertGreaterEqual(result["accuracy"], 0.0)
self.assertLessEqual(result["accuracy"], 1.0)
def test_all_classification_metrics(self):
"""Test that all supported classification metrics compute."""
all_metrics = [
"pr_auc", "roc_auc", "accuracy", "balanced_accuracy",
"f1", "precision", "recall", "cohen_kappa", "jaccard",
]
result = binary_metrics_fn(
self.y_true, self.y_prob, metrics=all_metrics,
)
for metric in all_metrics:
self.assertIn(metric, result)
self.assertIsInstance(result[metric], float)
@unittest.skip(
"ece_confidence_binary expects 2D arrays but binary_metrics_fn "
"passes 1D - see calibration.py:150"
)
def test_calibration_metrics(self):
"""Test ECE and adaptive ECE metrics."""
result = binary_metrics_fn(
self.y_true, self.y_prob, metrics=["ECE", "ECE_adapt"],
)
self.assertIn("ECE", result)
self.assertIn("ECE_adapt", result)
self.assertGreaterEqual(result["ECE"], 0.0)
self.assertGreaterEqual(result["ECE_adapt"], 0.0)
def test_perfect_predictions(self):
"""Test metrics with perfect predictions."""
y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.0, 0.0, 1.0, 1.0])
result = binary_metrics_fn(
y_true, y_prob, metrics=["accuracy", "f1"],
)
self.assertEqual(result["accuracy"], 1.0)
self.assertEqual(result["f1"], 1.0)
def test_custom_threshold(self):
"""Test that custom threshold changes predictions."""
result_low = binary_metrics_fn(
self.y_true, self.y_prob,
metrics=["accuracy"], threshold=0.3,
)
result_high = binary_metrics_fn(
self.y_true, self.y_prob,
metrics=["accuracy"], threshold=0.9,
)
# Different thresholds should generally give different results
self.assertIsInstance(result_low["accuracy"], float)
self.assertIsInstance(result_high["accuracy"], float)
def test_metric_values_in_range(self):
"""Test that all metric values are in valid ranges."""
all_metrics = [
"pr_auc", "roc_auc", "accuracy", "balanced_accuracy",
"f1", "precision", "recall", "jaccard",
]
result = binary_metrics_fn(
self.y_true, self.y_prob, metrics=all_metrics,
)
for metric in all_metrics:
self.assertGreaterEqual(
result[metric], 0.0, f"{metric} below 0",
)
self.assertLessEqual(
result[metric], 1.0, f"{metric} above 1",
)
def test_unknown_metric_raises(self):
"""Test that unknown metric name raises ValueError."""
with self.assertRaises(ValueError):
binary_metrics_fn(
self.y_true, self.y_prob, metrics=["nonexistent"],
)
def test_single_metric(self):
"""Test requesting a single metric."""
result = binary_metrics_fn(
self.y_true, self.y_prob, metrics=["roc_auc"],
)
self.assertEqual(len(result), 1)
self.assertIn("roc_auc", result)
if __name__ == "__main__":
unittest.main()