Skip to content

Commit 1acc151

Browse files
committed
pearsonr
1 parent 0b67d6c commit 1acc151

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

CompStats/metrics.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from functools import wraps
1515
from sklearn import metrics
16+
from scipy import stats
1617
from CompStats.interface import Perf
1718
from CompStats.utils import metrics_docs
1819

@@ -651,7 +652,6 @@ def inner(y, hy):
651652
**kwargs)
652653

653654

654-
@metrics_docs(hy_name='y_pred', attr_name='score_func')
655655
def d2_absolute_error_score(y_true,
656656
*y_pred,
657657
sample_weight=None,
@@ -672,3 +672,37 @@ def inner(y, hy):
672672
num_samples=num_samples, n_jobs=n_jobs,
673673
use_tqdm=use_tqdm,
674674
**kwargs)
675+
676+
677+
def pearsonr(y_true, *y_pred,
             alternative='two-sided', method=None,
             num_samples: int=500,
             n_jobs: int=-1,
             use_tqdm=True,
             **kwargs):
    """Create a :py:class:`~CompStats.interface.Perf` whose :py:attr:`score_func` is the statistic returned by :py:func:`~scipy.stats.pearsonr`.

    :param y_true: True measurement or could be a pandas.DataFrame where column label 'y' corresponds to the true measurement.
    :type y_true: numpy.ndarray or pandas.DataFrame
    :param y_pred: Predictions, the algorithms will be identified with alg-k where k=1 is the first argument included in :py:attr:`y_pred`.
    :type y_pred: numpy.ndarray
    :param alternative: Passed unchanged to :py:func:`~scipy.stats.pearsonr`, default='two-sided'.
    :param method: Passed unchanged to :py:func:`~scipy.stats.pearsonr`, default=None.
    :param num_samples: Number of bootstrap samples, default=500.
    :type num_samples: int
    :param n_jobs: Number of jobs to compute the statistic, default=-1 corresponding to use all threads.
    :type n_jobs: int
    :param use_tqdm: Whether to use tqdm.tqdm to visualize the progress, default=True
    :type use_tqdm: bool
    :param kwargs: Predictions, the algorithms will be identified using the keyword
    :type kwargs: numpy.ndarray
    """

    @wraps(stats.pearsonr)
    def correlation(y, hy):
        # Only the correlation coefficient is used as the score; the rest
        # of the scipy result object is discarded.
        result = stats.pearsonr(y, hy,
                                alternative=alternative,
                                method=method)
        return result.statistic

    return Perf(
        y_true,
        *y_pred,
        score_func=correlation,
        error_func=None,
        num_samples=num_samples,
        n_jobs=n_jobs,
        use_tqdm=use_tqdm,
        **kwargs,
    )

CompStats/tests/test_metrics.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,4 +476,21 @@ def test_d2_absolute_error_score():
476476
forest=hy,
477477
num_samples=50)
478478
_ = metrics.d2_absolute_error_score(y_val, hy)
479-
assert _ == perf.statistic
479+
assert _ == perf.statistic
480+
481+
482+
def test_pearsonr():
    """The Perf statistic must equal scipy's Pearson correlation."""
    from CompStats.metrics import pearsonr
    from scipy import stats

    # Fit a forest on a 70/30 split of the diabetes data and predict the
    # held-out part.
    features, target = load_diabetes(return_X_y=True)
    X_train, X_val, y_train, y_val = train_test_split(features, target,
                                                      test_size=0.3)
    forest = RandomForestRegressor()
    forest.fit(X_train, y_train)
    hy = forest.predict(X_val)

    perf = pearsonr(y_val, forest=hy, num_samples=50)
    expected = stats.pearsonr(y_val, hy)
    assert expected.statistic == perf.statistic

0 commit comments

Comments
 (0)