Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 15 additions & 28 deletions src/pyrecest/utils/association_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,7 @@

@dataclass(frozen=True, init=False)
class NamedPairwiseFeatureSchema:
"""Named schema for building pairwise feature tensors from components.

Parameters
----------
feature_names:
Ordered names of the feature planes to stack on the final tensor axis.
If no transform is registered for a name, the same key is looked up in
``components`` when building a tensor.
transforms:
Optional mapping from feature name to a callable receiving the full
component mapping and returning the feature plane. This keeps derived
features domain-specific without baking their semantics into PyRecEst.
"""
"""Named schema for building pairwise feature tensors from components."""

feature_names: tuple[str, ...]
transforms: Mapping[str, FeatureTransform]
Expand All @@ -63,14 +51,12 @@ def __iter__(self) -> Iterator[str]:
return iter(self.feature_names)

def feature_index(self, feature_name: str) -> int:
"""Return the last-axis index of a named feature."""
try:
return self.feature_names.index(feature_name)
except ValueError as exc:
raise KeyError(f"Unknown feature name {feature_name!r}") from exc

def build_tensor(self, components: Mapping[str, Any]) -> Any:
"""Build a ``(..., n_features)`` tensor from named pairwise components."""
return pairwise_feature_tensor(
components, self.feature_names, transforms=self.transforms
)
Expand Down Expand Up @@ -112,12 +98,7 @@ def pairwise_feature_tensor(

@dataclass(frozen=True, init=False)
class CalibratedPairwiseAssociationModel:
"""Association model wrapper that keeps a named pairwise feature schema.

The wrapped model may expose ``predict_match_probability``,
``predict_proba``, or ``pairwise_cost_matrix``. Component mappings are first
converted to tensors via ``schema`` before being passed to the model.
"""
"""Association model wrapper that keeps a named pairwise feature schema."""

model: Any
schema: NamedPairwiseFeatureSchema
Expand Down Expand Up @@ -149,15 +130,12 @@ def __init__(

@property
def feature_names(self) -> tuple[str, ...]:
"""Ordered feature names expected by the wrapped model."""
return self.schema.feature_names

def build_feature_tensor(self, components: Mapping[str, Any]) -> Any:
"""Build a feature tensor using this model's schema."""
return self.schema.build_tensor(components)

def predict_match_probability(self, features_or_components: Any) -> Any:
"""Return calibrated match probabilities for tensors or components."""
features = self._features_from_components_or_tensor(features_or_components)
if hasattr(self.model, "predict_match_probability"):
probabilities = self.model.predict_match_probability(features)
Expand All @@ -175,7 +153,6 @@ def predict_match_probability(self, features_or_components: Any) -> Any:
def pairwise_probability_matrix_from_components(
self, components: Mapping[str, Any]
) -> Any:
"""Convert components into calibrated pairwise probabilities."""
return self.predict_match_probability(components)

def pairwise_cost_matrix_from_components(
Expand All @@ -184,7 +161,6 @@ def pairwise_cost_matrix_from_components(
*,
mode: _COST_MODE = "negative_log_probability",
) -> Any:
"""Convert components into calibrated assignment costs."""
return self.pairwise_cost_matrix(components, mode=mode)

def pairwise_cost_matrix(
Expand All @@ -193,7 +169,6 @@ def pairwise_cost_matrix(
*,
mode: _COST_MODE = "negative_log_probability",
) -> Any:
"""Convert features or components into an assignment cost matrix."""
probabilities = clip(
self.predict_match_probability(features_or_components),
self.probability_clip,
Expand All @@ -216,7 +191,7 @@ def _predict_proba_probability(self, features: Any) -> Any:
self.model.predict_proba(flattened_features), dtype=float64
)
if probabilities.ndim >= 2 and probabilities.shape[-1] == 2:
probabilities = probabilities[..., 1]
probabilities = probabilities[..., self._predict_proba_positive_class_index()]
elif probabilities.ndim != 1:
raise ValueError(
"predict_proba must return probabilities with shape (n_samples,) or (n_samples, 2)"
Expand All @@ -225,6 +200,18 @@ def _predict_proba_probability(self, features: Any) -> Any:
return asarray(probabilities[0], dtype=float64)
return probabilities.reshape(original_shape)

def _predict_proba_positive_class_index(self) -> int:
classes = getattr(self.model, "classes_", None)
if classes is None:
return 1
classes = asarray(classes).reshape(-1)
if classes.shape != (2,):
return 1
for class_index, class_label in enumerate(classes):
if bool(class_label == 1):
return class_index
return 1


def _normalize_feature_names(feature_names: Sequence[str]) -> tuple[str, ...]:
if isinstance(feature_names, str):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_association_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,23 @@ def test_calibrated_pairwise_association_model_uses_named_components(self):
npt.assert_array_equal(argmax(probabilities, axis=1), array([0, 1]))
npt.assert_array_equal(argmin(costs, axis=1), array([0, 1]))

def test_calibrated_predict_proba_respects_classes_order(self):
class ReversedClassPredictProbaModel:
classes_ = array([1, 0])

def predict_proba(self, features):
del features
return array([[0.8, 0.2], [0.3, 0.7]])

calibrated_model = CalibratedPairwiseAssociationModel(
ReversedClassPredictProbaModel(), feature_names=("distance", "similarity")
)
features = array([[0.1, 0.9], [2.0, 0.1]])

probabilities = calibrated_model.predict_match_probability(features)

npt.assert_allclose(probabilities, array([0.8, 0.3]))


if __name__ == "__main__":
unittest.main()
Loading