diff --git a/src/pyrecest/utils/association_features.py b/src/pyrecest/utils/association_features.py
index 526fc1465..4bd2247f9 100644
--- a/src/pyrecest/utils/association_features.py
+++ b/src/pyrecest/utils/association_features.py
@@ -201,18 +201,40 @@ def _predict_proba_probability(self, features: Any) -> Any:
         return probabilities.reshape(original_shape)
 
     def _predict_proba_positive_class_index(self) -> int:
-        classes = getattr(self.model, "classes_", None)
-        if classes is None:
-            return 1
-        classes = asarray(classes).reshape(-1)
-        if classes.shape != (2,):
+        classes = _class_labels_to_list(getattr(self.model, "classes_", None))
+        if len(classes) != 2:
             return 1
         for class_index, class_label in enumerate(classes):
-            if bool(class_label == 1):
-                return class_index
+            try:
+                if bool(class_label == 1):
+                    return class_index
+            except (TypeError, ValueError):
+                continue
         return 1
 
 
+def _class_labels_to_list(classes: Any) -> list[Any]:
+    if classes is None:
+        return []
+    if hasattr(classes, "detach"):
+        classes = classes.detach().cpu()
+    if hasattr(classes, "reshape"):
+        try:
+            classes = classes.reshape(-1)
+        except TypeError:
+            pass
+    if hasattr(classes, "tolist"):
+        labels = classes.tolist()
+    else:
+        try:
+            labels = list(classes)
+        except TypeError:
+            return []
+    if isinstance(labels, list):
+        return labels
+    return [labels]
+
+
 def _normalize_feature_names(feature_names: Sequence[str]) -> tuple[str, ...]:
     if isinstance(feature_names, str):
         raise ValueError("feature_names must be a sequence of names, not a string")
diff --git a/tests/test_association_models.py b/tests/test_association_models.py
index fcfaaf741..e07a7f188 100644
--- a/tests/test_association_models.py
+++ b/tests/test_association_models.py
@@ -285,6 +285,23 @@ def predict_proba(self, features):
 
         npt.assert_allclose(probabilities, array([0.8, 0.3]))
 
+    def test_calibrated_predict_proba_falls_back_for_nonnumeric_classes(self):
+        class TextClassPredictProbaModel:
+            classes_ = ["not_a_match", "match"]
+
+            def predict_proba(self, features):
+                del features
+                return array([[0.8, 0.2], [0.3, 0.7]])
+
+        calibrated_model = CalibratedPairwiseAssociationModel(
+            TextClassPredictProbaModel(), feature_names=("distance", "similarity")
+        )
+        features = array([[0.1, 0.9], [2.0, 0.1]])
+
+        probabilities = calibrated_model.predict_match_probability(features)
+
+        npt.assert_allclose(probabilities, array([0.2, 0.7]))
+
 
 if __name__ == "__main__":
     unittest.main()