Skip to content
Draft
4 changes: 4 additions & 0 deletions doc/changes/dev/13909.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Batch and vectorise classifier estimation and scoring in
:meth:`mne.decoding.GeneralizingEstimator.score` for ``scoring=None``,
``"accuracy"``, ``"balanced_accuracy"`` and ``"roc_auc"``, by
:newcontrib:`Mathias Sablé-Meyer`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
.. _Martin Luessi: https://github.com/mluessi
.. _Martin Oberg: https://github.com/obergmartin
.. _Martin Schulz: https://github.com/marsipu
.. _Mathias Sablé-Meyer: https://s-m.ac/
.. _Mathieu Scheltienne: https://github.com/mscheltienne
.. _Mathurin Massias: https://mathurinm.github.io/
.. _Mats van Es: https://github.com/matsvanes
Expand Down
139 changes: 128 additions & 11 deletions mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging

import numpy as np
from scipy.stats import rankdata
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.metrics import check_scoring
from sklearn.preprocessing import LabelEncoder
Expand Down Expand Up @@ -743,17 +744,133 @@ def _gl_score(estimators, scoring, X, y, pb):
"""
# FIXME: The level of parallelization may be a bit high, and might be memory
# consuming. Perhaps need to lower it down to the loop across X slices.
score_shape = [len(estimators), X.shape[-1]]
for jj in range(X.shape[-1]):
for ii, est in enumerate(estimators):
_score = scoring(est, X[..., jj], y)
# Initialize array of predictions on the first score iteration
if (ii == 0) and (jj == 0):
dtype = type(_score)
score = np.zeros(score_shape, dtype)
score[ii, jj, ...] = _score

pb.update(jj * len(estimators) + ii + 1)
n_sample, n_iter = X.shape[0], X.shape[-1]
n_train = len(estimators)
score_shape = [n_train, n_iter]
score = None

# scoring=None goes through sklearn's _PassthroughScorer, which delegates
# to estimator.score(X, y). For a classifier that inherits
# ClassifierMixin.score unchanged, that is plain accuracy, so we substitute
# the explicit "accuracy" scorer, which the batching logic below recognises.
# We compare `type(est).score.__qualname__` rather than `.__name__` because
# the bare name is "score" no matter which class defined the method. An
# inherited method has qualname "ClassifierMixin.score", whereas any override
# resolves to "<Subclass>.score". We only take over the inherited method.
if len(estimators) and getattr(scoring, "_score_func", None) is None:
qname = getattr(type(estimators[0]).score, "__qualname__", "")
if qname == "ClassifierMixin.score":
scoring = check_scoring(estimators[0], "accuracy")

# Detect whether scoring can be batched, based on the scorer's response
# method. Recognised values:
# * "predict"
# * "predict_proba"
# * "decision_function"
# * "default" (= predict)
# * a tuple of those, e.g. roc_auc = ("decision_function", "predict_proba")
score_func = getattr(scoring, "_score_func", None)
rm = getattr(scoring, "_response_method", None)
valid = {"predict", "predict_proba", "decision_function"}
if rm == "default":
response_method = "predict"
elif isinstance(rm, str) and rm in valid:
response_method = rm
elif isinstance(rm, tuple) and all(m in valid for m in rm):
response_method = rm
else:
response_method = None
can_batch = score_func is not None and response_method is not None

# If we can't batch, we fall back to a simple nested loop over slices and
# estimators. This covers scoring=None with an overridden `score` method,
# as well as unrecognised scorers.
if not can_batch:
for jj in range(n_iter):
for ii, est in enumerate(estimators):
_score = scoring(est, X[..., jj], y)
if (ii == 0) and (jj == 0):
score = np.zeros(score_shape, type(_score))
score[ii, jj, ...] = _score
pb.update(jj * n_train + ii + 1)
return score

# We can batch; the logic is: reshape X, predict once, reshape back, score
# First: stack X across slices for one batched response call per estimator
X_stack = np.moveaxis(X, -1, 1)
X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:])

# Use the provided response method, or pick the first one supported
# by the estimator
if isinstance(response_method, str):
method = response_method
else:
for m in response_method:
if hasattr(estimators[0], m):
method = m
break

# Ensures score_func(..., **kwargs) doesn't crash when scoring._kwargs=None
kwargs = scoring._kwargs or {}

# Batched path: when we recognise score_func, build `batched_score` that
# scores all n_iter slices in a single vectorised reduction. It stays None
# for unrecognised scorers, which fall back to the per-slice loop below.
sign = scoring._sign
batched_score = None
if not kwargs and y.ndim == 1:
name = getattr(score_func, "__name__", "")
if name == "accuracy_score" and response_method == "predict":

def batched_score(y_pred):
return sign * (y_pred == y[:, None]).mean(axis=0)
elif name == "balanced_accuracy_score" and response_method == "predict":
classes = np.unique(y)

def batched_score(y_pred):
return sign * np.stack(
[(y_pred[y == c] == c).mean(axis=0) for c in classes]
).mean(axis=0)
elif name == "roc_auc_score" and method in (
"predict_proba",
"decision_function",
):
classes = np.unique(y)
if len(classes) == 2: # multi-class needs ovr/ovo; defer
pos = y == classes[1]
n_pos, n_neg = int(pos.sum()), int((~pos).sum())
if n_pos and n_neg: # degenerate folds raise downstream in sklearn

def batched_score(y_pred):
# Mann-Whitney U identity with average-rank tie
# correction. Matches sklearn's roc_auc_score to within
# floating point precision, but is computed differently
ranks = rankdata(y_pred, method="average", axis=0)
return (
sign
* (ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0)
/ (n_pos * n_neg)
)

for ii, est in enumerate(estimators):
y_pred = getattr(est, method)(X_stack)
# predict_proba returns probabilities for both classes; use the
# positive-class probabilities expected by binary scorers
if method == "predict_proba" and y_pred.ndim == 2 and y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
# Now, reshape back the prediction, then score
y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:])
# Either we can score with batching (if) or we loop again (else)
if batched_score is not None:
row = batched_score(y_pred)
if ii == 0:
score = np.zeros(score_shape, row.dtype)
score[ii] = row
pb.update((ii + 1) * n_iter)
else:
for jj in range(n_iter):
_score = sign * score_func(y, y_pred[:, jj], **kwargs)
if (ii == 0) and (jj == 0):
score = np.zeros(score_shape, type(_score))
score[ii, jj, ...] = _score
pb.update(ii * n_iter + jj + 1)
return score


Expand Down
11 changes: 8 additions & 3 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_equal
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

sklearn = pytest.importorskip("sklearn")

Expand Down Expand Up @@ -246,7 +246,11 @@ def test_generalization_light(metadata_routing):
gl.fit(X, y)
score = gl.score(X, y)
auc = roc_auc_score(y, gl.estimators_[0].predict_proba(X[..., 0])[..., 1])
assert_equal(score[0, 0], auc)

# The rank identity implemented when batching gives the same AUC as sklearn
# within floating point precision, but implements it with different
# operations. A bit-exact match would need a loop, defeating the batching.
assert_allclose(score[0, 0], auc)

for scoring in ["foo", 999]:
gl = GeneralizingEstimator(logreg, scoring=scoring)
Expand All @@ -267,7 +271,8 @@ def test_generalization_light(metadata_routing):
[roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds]
for _y_preds in gl.decision_function(X).transpose(1, 2, 0)
]
assert_array_equal(score, manual_score)
# allclose instead of equal: see above, batching roc_auc forces this.
assert_allclose(score, manual_score)

# n_jobs
gl = GeneralizingEstimator(logreg, n_jobs=2)
Expand Down
Loading