diff --git a/src/pyrecest/utils/association_features.py b/src/pyrecest/utils/association_features.py index 4bd2247f9..d4dfb1822 100644 --- a/src/pyrecest/utils/association_features.py +++ b/src/pyrecest/utils/association_features.py @@ -191,7 +191,9 @@ def _predict_proba_probability(self, features: Any) -> Any: self.model.predict_proba(flattened_features), dtype=float64 ) if probabilities.ndim >= 2 and probabilities.shape[-1] == 2: - probabilities = probabilities[..., self._predict_proba_positive_class_index()] + probabilities = probabilities[ + ..., self._predict_proba_positive_class_index() + ] elif probabilities.ndim != 1: raise ValueError( "predict_proba must return probabilities with shape (n_samples,) or (n_samples, 2)" diff --git a/tests/test_quadratic_assignment_defaults.py b/tests/test_quadratic_assignment_defaults.py index 03382d73f..3595408bb 100644 --- a/tests/test_quadratic_assignment_defaults.py +++ b/tests/test_quadratic_assignment_defaults.py @@ -1,7 +1,6 @@ import importlib.util import numpy as np - from pyrecest._backend import numpy as numpy_backend pytorch_backend = None