Skip to content

Commit 532650e

Browse files
committed
name in the first call
1 parent 7ba4cf7 commit 532650e

2 files changed

Lines changed: 16 additions & 3 deletions

File tree

CompStats/interface.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class Perf(object):
9090
9191
"""
9292
def __init__(self, y_true, *y_pred,
93+
name:str=None,
9394
score_func=balanced_accuracy_score,
9495
error_func=None,
9596
num_samples: int=500,
@@ -100,8 +101,13 @@ def __init__(self, y_true, *y_pred,
100101
self.score_func = score_func
101102
self.error_func = error_func
102103
algs = {}
103-
for k, v in enumerate(y_pred):
104-
algs[f'alg-{k+1}'] = np.asanyarray(v)
104+
if name is not None:
105+
if isinstance(name, str):
106+
name = [name]
107+
else:
108+
name = [f'alg-{k+1}' for k, _ in enumerate(y_pred)]
109+
for key, v in zip(name, y_pred):
110+
algs[key] = np.asanyarray(v)
105111
algs.update(**kwargs)
106112
self.predictions = algs
107113
self.y_true = y_true
@@ -519,7 +525,7 @@ def y_true(self, value):
519525
algs[c] = value[c].to_numpy()
520526
self.predictions.update(algs)
521527
return
522-
self._y_true = value
528+
self._y_true = np.asanyarray(value)
523529

524530
@property
525531
def score_func(self):

CompStats/tests/test_interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323
from CompStats.tests.test_performance import DATA
2424

2525

26+
def test_Perf_name():
27+
"""Test Perf name keyword"""
28+
from CompStats.metrics import f1_score
29+
score = f1_score([1, 0, 1], [1, 0, 0], name='algo')
30+
assert 'algo' in score.predictions
31+
32+
2633
def test_Perf_plot_col_wrap():
2734
"""Test plot when 2 classes"""
2835
from CompStats.metrics import f1_score

0 commit comments

Comments
 (0)