Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/pyrecest/utils/association_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,40 @@ def _predict_proba_probability(self, features: Any) -> Any:
return probabilities.reshape(original_shape)

def _predict_proba_positive_class_index(self) -> int:
    """Return the ``predict_proba`` column index of the positive class.

    The labels are read from ``self.model.classes_`` and flattened with
    ``_class_labels_to_list``. The index of the first label that compares
    equal to ``1`` is returned.

    Returns:
        The matching column index, or ``1`` when the model does not expose
        exactly two labels or when no label compares equal to ``1``.
    """
    classes = _class_labels_to_list(getattr(self.model, "classes_", None))
    if len(classes) != 2:
        return 1
    for class_index, class_label in enumerate(classes):
        try:
            if bool(class_label == 1):
                return class_index
        except (TypeError, ValueError):
            # A label whose comparison result cannot be reduced to a single
            # bool is skipped instead of aborting the lookup.
            continue
    return 1


def _class_labels_to_list(classes: Any) -> list[Any]:
    """Convert a container of class labels into a plain Python list.

    Args:
        classes: ``None``, an object offering ``detach``/``reshape``/``tolist``
            methods, or any other iterable of labels.

    Returns:
        The labels as a list. ``None`` and objects that cannot be iterated
        yield an empty list; a ``tolist()`` result that is not itself a list
        is wrapped in a one-element list.
    """
    if classes is None:
        return []

    # Objects exposing ``detach`` are detached and moved to the CPU first.
    if hasattr(classes, "detach"):
        classes = classes.detach().cpu()

    # Flatten when possible; a ``reshape`` that rejects the call signature
    # leaves the value untouched.
    if hasattr(classes, "reshape"):
        try:
            classes = classes.reshape(-1)
        except TypeError:
            pass

    if hasattr(classes, "tolist"):
        labels = classes.tolist()
    else:
        try:
            labels = list(classes)
        except TypeError:
            # Not iterable: there are no labels to report.
            return []

    if isinstance(labels, list):
        return labels
    return [labels]


def _normalize_feature_names(feature_names: Sequence[str]) -> tuple[str, ...]:
if isinstance(feature_names, str):
raise ValueError("feature_names must be a sequence of names, not a string")
Expand Down
17 changes: 17 additions & 0 deletions tests/test_association_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ def predict_proba(self, features):

npt.assert_allclose(probabilities, array([0.8, 0.3]))

def test_calibrated_predict_proba_falls_back_for_nonnumeric_classes(self):
    """Text class labels make the model report the second probability column."""

    class TextClassPredictProbaModel:
        # Neither label compares equal to 1.
        classes_ = ["not_a_match", "match"]

        def predict_proba(self, features):
            del features
            return array([[0.8, 0.2], [0.3, 0.7]])

    calibrated_model = CalibratedPairwiseAssociationModel(
        TextClassPredictProbaModel(), feature_names=("distance", "similarity")
    )
    features = array([[0.1, 0.9], [2.0, 0.1]])

    probabilities = calibrated_model.predict_match_probability(features)

    # The expected values are column index 1 of the stubbed output.
    npt.assert_allclose(probabilities, array([0.2, 0.7]))


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()
Loading