diff --git a/src/pyrecest/utils/association_features.py b/src/pyrecest/utils/association_features.py index 4bd2247f9..d8287f153 100644 --- a/src/pyrecest/utils/association_features.py +++ b/src/pyrecest/utils/association_features.py @@ -205,14 +205,23 @@ def _predict_proba_positive_class_index(self) -> int: if len(classes) != 2: return 1 for class_index, class_label in enumerate(classes): - try: - if bool(class_label == 1): - return class_index - except (TypeError, ValueError): - continue + if _is_positive_binary_label(class_label): + return class_index return 1 +def _is_positive_binary_label(class_label: Any) -> bool: + try: + if bool(class_label == 1): + return True + except (TypeError, ValueError): + pass + if isinstance(class_label, str): + normalized_label = class_label.strip().casefold() + return normalized_label in {"1", "true"} + return False + + def _class_labels_to_list(classes: Any) -> list[Any]: if classes is None: return [] diff --git a/tests/test_calibrated_association_predict_proba_string_labels.py b/tests/test_calibrated_association_predict_proba_string_labels.py new file mode 100644 index 000000000..4371ceb3c --- /dev/null +++ b/tests/test_calibrated_association_predict_proba_string_labels.py @@ -0,0 +1,29 @@ +import unittest + +import numpy.testing as npt + +from pyrecest.backend import array +from pyrecest.utils import CalibratedPairwiseAssociationModel + + +class TestCalibratedAssociationPredictProbaStringLabels(unittest.TestCase): + def test_predict_proba_respects_stringified_binary_class_order(self): + class StringBinaryPredictProbaModel: + classes_ = ["1", "0"] + + def predict_proba(self, features): + del features + return array([[0.8, 0.2], [0.3, 0.7]]) + + calibrated_model = CalibratedPairwiseAssociationModel( + StringBinaryPredictProbaModel(), feature_names=("distance", "similarity") + ) + features = array([[0.1, 0.9], [2.0, 0.1]]) + + probabilities = calibrated_model.predict_match_probability(features) + + npt.assert_allclose(probabilities, array([0.8, 0.3])) + + +if __name__ == "__main__": + unittest.main()