Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/pyrecest/utils/association_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,23 @@ def _predict_proba_positive_class_index(self) -> int:
if len(classes) != 2:
return 1
for class_index, class_label in enumerate(classes):
try:
if bool(class_label == 1):
return class_index
except (TypeError, ValueError):
continue
if _is_positive_binary_label(class_label):
return class_index
return 1


def _is_positive_binary_label(class_label: Any) -> bool:
try:
if bool(class_label == 1):
return True
except (TypeError, ValueError):
pass
if isinstance(class_label, str):
normalized_label = class_label.strip().casefold()
return normalized_label in {"1", "true"}
return False


def _class_labels_to_list(classes: Any) -> list[Any]:
if classes is None:
return []
Expand Down
29 changes: 29 additions & 0 deletions tests/test_calibrated_association_predict_proba_string_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest

import numpy.testing as npt

from pyrecest.backend import array
from pyrecest.utils import CalibratedPairwiseAssociationModel


class TestCalibratedAssociationPredictProbaStringLabels(unittest.TestCase):
def test_predict_proba_respects_stringified_binary_class_order(self):
class StringBinaryPredictProbaModel:
classes_ = ["1", "0"]

def predict_proba(self, features):
del features
return array([[0.8, 0.2], [0.3, 0.7]])

calibrated_model = CalibratedPairwiseAssociationModel(
StringBinaryPredictProbaModel(), feature_names=("distance", "similarity")
)
features = array([[0.1, 0.9], [2.0, 0.1]])

probabilities = calibrated_model.predict_match_probability(features)

npt.assert_allclose(probabilities, array([0.8, 0.3]))


if __name__ == "__main__":
unittest.main()
Loading