-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathweighted_majority_ensemble.py
More file actions
107 lines (95 loc) · 5.24 KB
/
weighted_majority_ensemble.py
File metadata and controls
107 lines (95 loc) · 5.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from chebifier.ensemble.base_ensemble import BaseEnsemble
class WMVwithPPVNPVEnsemble(BaseEnsemble):
    """Weighted majority vote ensemble that trusts models by class-wise PPV / NPV.

    A model's positive predictions for a class are scaled by a factor derived
    from its positive predictive value (PPV) for that class; its negative
    predictions are scaled by a factor derived from its negative predictive
    value (NPV).
    """

    def __init__(
        self, config_path=None, weighting_strength=1, weighting_exponent=1, **kwargs
    ):
        """Initialise the base ensemble and remember the weighting hyperparameters.

        The class-specific factor applied to a positive prediction is

            (weighting_strength * PPV + (1 - weighting_strength)) ** weighting_exponent

        and the factor applied to a negative prediction is

            (weighting_strength * NPV + (1 - weighting_strength)) ** weighting_exponent

        where PPV / NPV are the model's class-specific positive / negative
        predictive values on the validation set.

        Args:
            config_path: Forwarded unchanged to ``BaseEnsemble.__init__``.
            weighting_strength: Blend between the predictive value (1) and a
                constant factor of 1 (0).
            weighting_exponent: Power applied to the blended value.
            **kwargs: Forwarded unchanged to ``BaseEnsemble.__init__``.
        """
        super().__init__(config_path, **kwargs)
        self.weighting_exponent = weighting_exponent
        self.weighting_strength = weighting_strength

    def calculate_classwise_weights(self, predicted_classes):
        """Build the per-class, per-model weights for positive and negative votes.

        Args:
            predicted_classes: Mapping from class identifier to that class's
                row index in the returned tensors (it is queried with ``in``
                and indexed with the class identifier).

        Returns:
            A pair ``(positive_weights, negative_weights)`` of tensors, each of
            shape ``(len(predicted_classes), len(self.models))``. Every entry
            is the model's ``model_weight``, multiplied by the PPV-based
            (respectively NPV-based) factor whenever the model provides
            ``"TP"``/``"FP"``/``"TN"``/``"FN"`` counts for that class.
        """
        strength = self.weighting_strength
        exponent = self.weighting_exponent
        num_rows = len(predicted_classes)
        num_models = len(self.models)

        positive_weights = torch.ones(num_rows, num_models)
        negative_weights = torch.ones(num_rows, num_models)

        for col, member in enumerate(self.models):
            # Every class starts out at the model's global weight.
            positive_weights[:, col] *= member.model_weight
            negative_weights[:, col] *= member.model_weight

            counts_by_class = member.classwise_weights
            if counts_by_class is not None:
                for label, counts in counts_by_class.items():
                    if label in predicted_classes:
                        # A predictive value with an empty denominator
                        # falls back to 1.0.
                        predicted_positive = counts["TP"] + counts["FP"]
                        if predicted_positive > 0:
                            ppv = counts["TP"] / predicted_positive
                        else:
                            ppv = 1.0

                        predicted_negative = counts["TN"] + counts["FN"]
                        if predicted_negative > 0:
                            npv = counts["TN"] / predicted_negative
                        else:
                            npv = 1.0

                        row = predicted_classes[label]
                        positive_weights[row, col] *= (
                            ppv * strength + (1 - strength)
                        ) ** exponent
                        negative_weights[row, col] *= (
                            npv * strength + (1 - strength)
                        ) ** exponent

        if not self.verbose_output:
            return positive_weights, negative_weights

        # Verbose mode: report each model's mean weight over all classes.
        print(
            "Calculated model weightings. The averages for positive / negative weights are:"
        )
        for i, model in enumerate(self.models):
            print(
                f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}"
            )
        return positive_weights, negative_weights
class WMVwithF1Ensemble(BaseEnsemble):
    """Weighted majority vote ensemble that trusts models by class-wise F1.

    Unlike a PPV / NPV weighting, the same F1-based weight is used for a
    model's positive and negative predictions of a class.
    """

    def __init__(
        self, config_path=None, weighting_strength=1, weighting_exponent=6.25, **kwargs
    ):
        """Initialise the base ensemble and store the weighting hyperparameters.

        For each class, the weight of a model is calculated as

            model_weight * (weighting_strength * F1 + (1 - weighting_strength)) ** weighting_exponent

        where F1 is the class-specific F1 score ("trust") of the model on the
        validation set.

        Args:
            config_path: Forwarded unchanged to ``BaseEnsemble.__init__``.
            weighting_strength: Blend between the F1 score (1) and a constant
                factor of 1 (0).
            weighting_exponent: Power applied to the blended value.
            **kwargs: Forwarded unchanged to ``BaseEnsemble.__init__``.
        """
        super().__init__(config_path, **kwargs)
        self.weighting_strength = weighting_strength
        self.weighting_exponent = weighting_exponent

    def calculate_classwise_weights(self, predicted_classes):
        """Build the per-class, per-model F1-based weights.

        Each entry starts at the model's ``model_weight``. If the model
        provides ``"TP"``/``"FP"``/``"FN"`` counts for a class and at least one
        of them is non-zero, the entry is multiplied by

            (weighting_strength * F1 + 1 - weighting_strength) ** weighting_exponent

        with ``F1 = 2 * TP / (2 * TP + FP + FN)``. Classes without counts, or
        whose counts are all zero (F1 undefined), keep the plain
        ``model_weight``.

        Args:
            predicted_classes: Mapping from class identifier to that class's
                row index in the returned tensor (it is queried with ``in``
                and indexed with the class identifier).

        Returns:
            A pair ``(positive_weights, negative_weights)``. Both elements are
            the *same* tensor object of shape
            ``(len(predicted_classes), len(self.models))``, so callers must not
            modify one of them in place and expect the other to stay unchanged.
        """
        weights_by_cls = torch.ones(len(predicted_classes), len(self.models))
        for j, model in enumerate(self.models):
            # Every class starts out at the model's global weight.
            weights_by_cls[:, j] *= model.model_weight
            if model.classwise_weights is None:
                continue
            for cls, weights in model.classwise_weights.items():
                if cls not in predicted_classes:
                    continue
                # Computed once; it serves both as the "F1 is defined" check
                # and as the divisor.
                denominator = 2 * weights["TP"] + weights["FP"] + weights["FN"]
                if denominator > 0:
                    f1 = 2 * weights["TP"] / denominator
                    weights_by_cls[predicted_classes[cls], j] *= (
                        self.weighting_strength * f1 + 1 - self.weighting_strength
                    ) ** self.weighting_exponent
        if self.verbose_output:
            print("Calculated model weightings. The average weights are:")
            for i, model in enumerate(self.models):
                print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")
        # The same tensor is used for positive and negative predictions.
        return weights_by_cls, weights_by_cls