Skip to content

Commit 42444f0

Browse files
committed
Implemented offset and predict in anomaly
1 parent dbd4d98 commit 42444f0

3 files changed

Lines changed: 66 additions & 84 deletions

File tree

distclassipy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from .distances import _ALL_METRICS, _UNIQUE_METRICS
3030

31-
__version__ = "0.2.2a3"
31+
__version__ = "0.2.2a4"
3232

3333
__all__ = [
3434
"DistanceMetricClassifier",

distclassipy/anomaly.py

Lines changed: 65 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
A module for distance anomalies.
3+
4+
This module implements the methods for distance anomalies.
5+
"""
6+
17
import numpy as np
28
from sklearn.base import BaseEstimator, OutlierMixin
39
from sklearn.utils.validation import check_is_fitted, check_array
@@ -26,9 +32,12 @@ class DistanceAnomaly(OutlierMixin, BaseEstimator):
2632
cluster_agg : {'min', 'mean', 'median'}, default='min'
2733
The aggregation method for distances to different class centroids for a
2834
single metric.
29-
- 'min': An object's distance is its distance to the *nearest* known class.
30-
- 'mean': An object's distance is its mean distance to all *nearest* known classes.
31-
- 'median': A more robust measure of an object's typical distance to all classes.
35+
- 'min': An object's distance is its distance to the *nearest* known
36+
class.
37+
- 'mean': An object's distance is its mean distance to all
38+
known classes.
39+
- 'median': A more robust measure of an object's typical distance
40+
to all classes.
3241
3342
metric_agg : {'median', 'mean', 'min', 'percentile_25'}, default='median'
3443
The method to aggregate scores from the ensemble of metrics.
@@ -81,8 +90,10 @@ def __init__(
8190

8291
def fit(self, X: np.ndarray, y: np.ndarray) -> "DistanceAnomaly":
8392
"""
84-
Fit the anomaly detector by training the underlying DistanceMetricClassifier
85-
on the normal data.
93+
Fit the anomaly detector.
94+
95+
Fit the anomaly detector by training the underlying
96+
DistanceMetricClassifier on the normal data.
8697
8798
Parameters
8899
----------
@@ -109,15 +120,18 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "DistanceAnomaly":
109120
self.metrics_ = self.metrics
110121

111122
# Calculate anomaly threshold based on train scores
112-
# train_scores = self.decision_function(X)
113-
# self.offset_ = np.quantile(train_scores, 1.0 - self.contamination)
123+
train_scores = self.decision_function(X)
124+
125+
self.offset_ = np.quantile(train_scores, 1.0 - self.contamination)
114126

115127
return self
116128

117129
def decision_function(self, X: np.ndarray) -> np.ndarray:
118130
"""
119-
Calculate the raw anomaly score for each sample. Higher scores are
120-
more anomalous.
131+
Decision function.
132+
133+
Calculates the raw anomaly score for each sample.
134+
Higher scores are more anomalous.
121135
122136
Parameters
123137
----------
@@ -135,11 +149,10 @@ def decision_function(self, X: np.ndarray) -> np.ndarray:
135149
metric_scores = []
136150

137151
for metric in self.metrics_:
138-
# Get dataframe for distances to all centroids from dcpy
139152
self.clf_.predict_and_analyse(X, metric=metric)
140153
dist_df = self.clf_.centroid_dist_df_
141154

142-
# 1. Aggregate distances across clusters the current metric
155+
# Aggregate distances across clusters for the current metric
143156
if self.cluster_agg == "min":
144157
score_for_metric = dist_df.min(axis=1).values
145158
elif self.cluster_agg == "median":
@@ -153,83 +166,63 @@ def decision_function(self, X: np.ndarray) -> np.ndarray:
153166

154167
metric_scores_arr = np.array(metric_scores).T # shape (n_samples, n_metrics)
155168
# remove infinities
156-
metric_scores_arr[metric_scores_arr == np.inf] = 1e9 # A large number
157-
metric_scores_arr[metric_scores_arr == -np.inf] = -1e9 # A large negative number
158-
169+
metric_scores_arr[metric_scores_arr == np.inf] = 1e9 # A large number
170+
metric_scores_arr[metric_scores_arr == -np.inf] = (
171+
-1e9
172+
) # A large negative number
159173

160174
if self.normalize_scores:
161-
# Scale scores for each metric (column) to be between 0 and 1
162-
# Compare with Rio notebook once.
175+
col_means = np.nanmean(metric_scores_arr, axis=0)
176+
inds = np.where(np.isnan(metric_scores_arr))
177+
metric_scores_arr[inds] = np.take(col_means, inds[1])
163178
metric_scores_arr = minmax_scale(metric_scores_arr, axis=0)
164-
165-
# 2. Aggregate scores across all metrics for final anomaly score
179+
180+
# Aggregate scores across all metrics for final anomaly score
166181
if self.metric_agg == "median":
167-
scores = np.median(metric_scores_arr, axis=1)
182+
scores = np.nanmedian(metric_scores_arr, axis=1)
168183
elif self.metric_agg == "mean":
169-
scores = np.mean(metric_scores_arr, axis=1)
184+
scores = np.nanmean(metric_scores_arr, axis=1)
170185
elif self.metric_agg == "min":
171-
scores = np.min(metric_scores_arr, axis=1)
186+
scores = np.nanmin(metric_scores_arr, axis=1)
172187
elif self.metric_agg == "percentile_25":
173-
scores = np.quantile(metric_scores_arr, 0.25, axis=1)
188+
scores = np.nanquantile(metric_scores_arr, 0.25, axis=1)
174189
else:
175190
raise ValueError(f"Unknown metric_agg method: {self.metric_agg}")
176191

177-
# # Threshold for predict() as per sklearn conventions
178-
# ## NOTE: DATA LEAKAGE CONCERN
179-
# ## FIX LATER
180-
# self.offset_ = np.quantile(scores, (1 - self.contamination))
181-
182192
return scores
183193

184194
def score_samples(self, X: np.ndarray) -> np.ndarray:
185195
"""
186196
Calculate the anomaly score, matching scikit-learn's convention.
187197
188-
Note: Opposite of decision_function. Higher scores mean less anomalous (more normal).
189-
This is for compatibility with tools that expect this behavior, like IsolationForest.
198+
Note: Opposite of decision_function. Higher scores mean less
199+
anomalous (more normal).
200+
This is for compatibility with tools that expect this
201+
behavior, like IsolationForest.
190202
"""
191203
return -self.decision_function(X)
192204

193-
# def predict(self, X: np.ndarray) -> np.ndarray:
194-
# """
195-
# Predict if a particular sample is an inlier (1) or outlie (-1).
196-
197-
# Parameters
198-
# ----------
199-
# X : array-like of shape (n_samples,)
200-
# The input samples.
201-
202-
# Returns
203-
# -------
204-
# is_outlier : ndarray of shape (n_samples,)
205-
# Returns -1 for outliers and 1 for inliers.
206-
# """
207-
# check_is_fitted(self)
208-
# scores = self.decision_function(X)
209-
# is_outlier = np.ones(X.shape[0], dtype=int)
210-
# is_outlier[scores >= self.offset_] = -1
211-
# return is_outlier
212-
213-
# def predict(self, X: np.ndarray) -> np.ndarray:
214-
# NOTE: UNCOMMENT AFTER FIXING ABOVE offset_ DATA LEAKAGE CONCERN
215-
# """
216-
# Predict if a particular sample is an inlier or outlier.
217-
218-
# Parameters
219-
# ----------
220-
# X : array-like of shape (n_samples,)
221-
# The input samples.
222-
223-
# Returns
224-
# -------
225-
# is_outlier : ndarray of shape (n_samples,)
226-
# Returns -1 for outliers and 1 for inliers.
227-
# """
228-
# scores = self.decision_function(X)
229-
# is_outlier = np.ones(X.shape[0], dtype=int)
230-
# is_outlier[scores >= self.offset_] = -1
231-
# return is_outlier
232-
233-
234-
# ref:
235-
# DOI: 10.2196/27172
205+
def predict(self, X: np.ndarray) -> np.ndarray:
206+
"""
207+
Predict if a particular sample is an inlier (1) or outlier (-1).
208+
209+
This method uses the threshold learned during the `fit` phase.
210+
211+
Parameters
212+
----------
213+
X : array-like of shape (n_samples, n_features)
214+
The input samples.
215+
216+
Returns
217+
-------
218+
is_outlier : ndarray of shape (n_samples,)
219+
Returns -1 for outliers and 1 for inliers.
220+
"""
221+
check_is_fitted(self)
222+
scores = self.decision_function(X)
223+
224+
# Compare scores against the pre-computed threshold
225+
is_outlier = np.full(X.shape[0], 1, dtype=int)
226+
is_outlier[scores >= self.offset_] = -1
227+
228+
return is_outlier

distclassipy/neighbors.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)