diff --git a/src/pyrecest/utils/association_features.py b/src/pyrecest/utils/association_features.py index df6e8fdc2..526fc1465 100644 --- a/src/pyrecest/utils/association_features.py +++ b/src/pyrecest/utils/association_features.py @@ -26,19 +26,7 @@ @dataclass(frozen=True, init=False) class NamedPairwiseFeatureSchema: - """Named schema for building pairwise feature tensors from components. - - Parameters - ---------- - feature_names: - Ordered names of the feature planes to stack on the final tensor axis. - If no transform is registered for a name, the same key is looked up in - ``components`` when building a tensor. - transforms: - Optional mapping from feature name to a callable receiving the full - component mapping and returning the feature plane. This keeps derived - features domain-specific without baking their semantics into PyRecEst. - """ + """Named schema for building pairwise feature tensors from components.""" feature_names: tuple[str, ...] transforms: Mapping[str, FeatureTransform] @@ -63,14 +51,12 @@ def __iter__(self) -> Iterator[str]: return iter(self.feature_names) def feature_index(self, feature_name: str) -> int: - """Return the last-axis index of a named feature.""" try: return self.feature_names.index(feature_name) except ValueError as exc: raise KeyError(f"Unknown feature name {feature_name!r}") from exc def build_tensor(self, components: Mapping[str, Any]) -> Any: - """Build a ``(..., n_features)`` tensor from named pairwise components.""" return pairwise_feature_tensor( components, self.feature_names, transforms=self.transforms ) @@ -112,12 +98,7 @@ def pairwise_feature_tensor( @dataclass(frozen=True, init=False) class CalibratedPairwiseAssociationModel: - """Association model wrapper that keeps a named pairwise feature schema. - - The wrapped model may expose ``predict_match_probability``, - ``predict_proba``, or ``pairwise_cost_matrix``. Component mappings are first - converted to tensors via ``schema`` before being passed to the model. - """ + """Association model wrapper that keeps a named pairwise feature schema.""" model: Any schema: NamedPairwiseFeatureSchema @@ -149,15 +130,12 @@ def __init__( @property def feature_names(self) -> tuple[str, ...]: - """Ordered feature names expected by the wrapped model.""" return self.schema.feature_names def build_feature_tensor(self, components: Mapping[str, Any]) -> Any: - """Build a feature tensor using this model's schema.""" return self.schema.build_tensor(components) def predict_match_probability(self, features_or_components: Any) -> Any: - """Return calibrated match probabilities for tensors or components.""" features = self._features_from_components_or_tensor(features_or_components) if hasattr(self.model, "predict_match_probability"): probabilities = self.model.predict_match_probability(features) @@ -175,7 +153,6 @@ def predict_match_probability(self, features_or_components: Any) -> Any: def pairwise_probability_matrix_from_components( self, components: Mapping[str, Any] ) -> Any: - """Convert components into calibrated pairwise probabilities.""" return self.predict_match_probability(components) def pairwise_cost_matrix_from_components( @@ -184,7 +161,6 @@ def pairwise_cost_matrix_from_components( *, mode: _COST_MODE = "negative_log_probability", ) -> Any: - """Convert components into calibrated assignment costs.""" return self.pairwise_cost_matrix(components, mode=mode) def pairwise_cost_matrix( @@ -193,7 +169,6 @@ def pairwise_cost_matrix( *, mode: _COST_MODE = "negative_log_probability", ) -> Any: - """Convert features or components into an assignment cost matrix.""" probabilities = clip( self.predict_match_probability(features_or_components), self.probability_clip, @@ -216,7 +191,7 @@ def _predict_proba_probability(self, features: Any) -> Any: self.model.predict_proba(flattened_features), dtype=float64 ) if probabilities.ndim >= 2 and probabilities.shape[-1] == 2: - probabilities = probabilities[..., 1] + probabilities = probabilities[..., self._predict_proba_positive_class_index()] elif probabilities.ndim != 1: raise ValueError( "predict_proba must return probabilities with shape (n_samples,) or (n_samples, 2)" @@ -225,6 +200,18 @@ def _predict_proba_probability(self, features: Any) -> Any: return asarray(probabilities[0], dtype=float64) return probabilities.reshape(original_shape) + def _predict_proba_positive_class_index(self) -> int: + classes = getattr(self.model, "classes_", None) + if classes is None: + return 1 + classes = asarray(classes).reshape(-1) + if classes.shape != (2,): + return 1 + for class_index, class_label in enumerate(classes): + if bool(class_label == 1): + return class_index + return 1 + def _normalize_feature_names(feature_names: Sequence[str]) -> tuple[str, ...]: if isinstance(feature_names, str): diff --git a/tests/test_association_models.py b/tests/test_association_models.py index d39b4c3f8..fcfaaf741 100644 --- a/tests/test_association_models.py +++ b/tests/test_association_models.py @@ -268,6 +268,23 @@ def test_calibrated_pairwise_association_model_uses_named_components(self): npt.assert_array_equal(argmax(probabilities, axis=1), array([0, 1])) npt.assert_array_equal(argmin(costs, axis=1), array([0, 1])) + def test_calibrated_predict_proba_respects_classes_order(self): + class ReversedClassPredictProbaModel: + classes_ = array([1, 0]) + + def predict_proba(self, features): + del features + return array([[0.8, 0.2], [0.3, 0.7]]) + + calibrated_model = CalibratedPairwiseAssociationModel( + ReversedClassPredictProbaModel(), feature_names=("distance", "similarity") + ) + features = array([[0.1, 0.9], [2.0, 0.1]]) + + probabilities = calibrated_model.predict_match_probability(features) + + npt.assert_allclose(probabilities, array([0.8, 0.3])) + if __name__ == "__main__": unittest.main()