From 511b03c98889c0bf1f006ecc6ab2490609a4fc95 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 9 Jan 2026 15:27:57 +0900 Subject: [PATCH 1/2] add check fitted logic --- pearsonify/__init__.py | 2 +- pearsonify/wrapper.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pearsonify/__init__.py b/pearsonify/__init__.py index 3e0058d..4d8fb73 100644 --- a/pearsonify/__init__.py +++ b/pearsonify/__init__.py @@ -1,3 +1,3 @@ __version__ = "1.0.1" -from .wrapper import Pearsonify \ No newline at end of file +from .wrapper import Pearsonify diff --git a/pearsonify/wrapper.py b/pearsonify/wrapper.py index b2573c5..dd676ad 100644 --- a/pearsonify/wrapper.py +++ b/pearsonify/wrapper.py @@ -1,10 +1,12 @@ -import numpy as np import matplotlib.pyplot as plt +import numpy as np from sklearn.base import BaseEstimator +from sklearn.utils.validation import NotFittedError, check_is_fitted + from .utils import ( - compute_pearson_residuals, - compute_confidence_intervals, calculate_coverage, + compute_confidence_intervals, + compute_pearson_residuals, ) @@ -24,7 +26,15 @@ def __init__(self, estimator: BaseEstimator, alpha=0.05): def fit(self, X_train, y_train, X_cal, y_cal): """Fit the model and compute Pearson residual-based quantile from calibration data.""" # Train the model if it's not already fitted - self.estimator.fit(X_train, y_train) + try: + check_is_fitted(self.estimator) + if not hasattr(self.estimator, "predict_proba"): + raise TypeError("The estimator must have 'predict_proba' method.") + except TypeError as e: + raise TypeError(f"Estimator validation failed: {e}") from e + except NotFittedError: + # Attempt to fit the estimator if not already fitted + self.estimator.fit(X_train, y_train) # Compute residuals on calibration set y_cal_pred_proba = self.estimator.predict_proba(X_cal)[:, 1] From ea7c0bb9e12e51907ddd009be88f20800e7a32dc Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 9 Jan 2026 15:38:23 +0900 Subject: [PATCH 2/2] follow the sourcery comment --- pearsonify/wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pearsonify/wrapper.py b/pearsonify/wrapper.py index dd676ad..ad9df74 100644 --- a/pearsonify/wrapper.py +++ b/pearsonify/wrapper.py @@ -28,8 +28,8 @@ def fit(self, X_train, y_train, X_cal, y_cal): # Train the model if it's not already fitted try: check_is_fitted(self.estimator) - if not hasattr(self.estimator, "predict_proba"): - raise TypeError("The estimator must have 'predict_proba' method.") + if not callable(getattr(self.estimator, "predict_proba", None)): + raise TypeError("The estimator must have a callable 'predict_proba' method.") except TypeError as e: raise TypeError(f"Estimator validation failed: {e}") from e except NotFittedError: