Skip to content

Commit 3f72e4e

Browse files
bruAristimunha, PierreGtch, tomMoral
authored
Update evaluation to use new splitters and include updates (#769)
* updating evaluation to use new splitters * cross-subject * including the whats new * updating the whats new * Update docs/source/whats_new.rst Co-authored-by: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Signed-off-by: Bru <b.aristimunha@gmail.com> * simple fit for everybody * updating the splitter * parallel evaluation now * solving the small issue * updating the evaluation * adjusting in the other evaluation too * updating * updating the evaluations * Apply suggestions from code review Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com> Signed-off-by: Bru <b.aristimunha@gmail.com> * updating base * updating the pyproject * trying to solve this shit... * crazy things here.. * too much things at the same time * reverting * evaluation * including acceptance test * forcing two reference results * reverting small detail * updating the pyproject * upgrading the mne version * solving issue with saving * scoring * fixing import --------- Signed-off-by: Bru <b.aristimunha@gmail.com> Signed-off-by: Bru <a.bruno@aluno.ufabc.edu.br> Co-authored-by: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>
1 parent 81b805a commit 3f72e4e

11 files changed

Lines changed: 261 additions & 221 deletions

File tree

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Enhancements
4242
- Adding :func:`moabb.analysis.plotting.dataset_bubble_plot` plus the corresponding tutorial (:gh:`753` by `Pierre Guetschel`_)
4343
- Adding :func:`moabb.datasets.utils.plot_all_datasets` and update the tutorial (:gh:`758` by `Pierre Guetschel`_)
4444
- Improve the dataset model cards in each API page (:gh:`765` by `Pierre Guetschel`_)
45+
- Refactor :class:`moabb.evaluation.CrossSessionEvaluation`, :class:`moabb.evaluation.CrossSubjectEvaluation` and :class:`moabb.evaluation.WithinSessionEvaluation` to use the new splitter classes (:gh:`769` by `Bruno Aristimunha`_)
4546
- Adding tutorial on using mne-features (:gh:`762` by `Alexander de Ranitz`_, `Luuk Neervens`_, `Charlynn van Osch`_ and `Bruno Aristimunha`_)
4647
- Creating tutorial to expose the pre-processing steps (:gh:`771` by `Bruno Aristimunha`_)
4748
- Add function to auto-generate tables for the paper results documentation page (:gh:`785` by `Lucas Heck`_)

examples/advanced_examples/plot_grid_search_withinsession.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
"""
1010

1111
import os
12-
from pickle import load
1312

1413
import matplotlib.pyplot as plt
1514
import seaborn as sns
@@ -132,44 +131,3 @@
132131
)
133132
sns.pointplot(data=result, y="score", x="pipeline", ax=axes, palette="Set1")
134133
axes.set_ylabel("ROC AUC")
135-
136-
##########################################################
137-
# Load Best Model Parameter
138-
# -------------------------
139-
# The best model are automatically saved in a pickle file, in the
140-
# results directory. It is possible to load those model for each
141-
# dataset, subject and session. Here, we could see that the grid
142-
# search found a l1_ratio that is different from the baseline
143-
# value.
144-
145-
with open(
146-
"./Results/Models_WithinSession/BNCI2014-001/1/1test/GridSearchEN/fitted_model_best.pkl",
147-
"rb",
148-
) as pickle_file:
149-
GridSearchEN_Session_E = load(pickle_file)
150-
151-
print(
152-
"Best Parameter l1_ratio Session_E GridSearchEN ",
153-
GridSearchEN_Session_E.best_params_["LogistReg__l1_ratio"],
154-
)
155-
156-
print(
157-
"Best Parameter l1_ratio Session_E VanillaEN: ",
158-
pipelines["VanillaEN"].steps[2][1].l1_ratio,
159-
)
160-
161-
with open(
162-
"./Results/Models_WithinSession/BNCI2014-001/1/0train/GridSearchEN/fitted_model_best.pkl",
163-
"rb",
164-
) as pickle_file:
165-
GridSearchEN_Session_T = load(pickle_file)
166-
167-
print(
168-
"Best Parameter l1_ratio Session_T GridSearchEN ",
169-
GridSearchEN_Session_T.best_params_["LogistReg__l1_ratio"],
170-
)
171-
172-
print(
173-
"Best Parameter l1_ratio Session_T VanillaEN: ",
174-
pipelines["VanillaEN"].steps[2][1].l1_ratio,
175-
)

moabb/evaluations/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010
WithinSessionEvaluation,
1111
)
1212
from .splitters import CrossSessionSplitter, CrossSubjectSplitter, WithinSessionSplitter
13-
from .utils import create_save_path, save_model_cv, save_model_list
13+
from .utils import _create_save_path, _save_model_cv

moabb/evaluations/base.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,23 @@
33
from warnings import warn
44

55
import pandas as pd
6+
from joblib import Parallel, delayed
67
from sklearn.base import BaseEstimator
7-
from sklearn.model_selection import GridSearchCV
88

99
from moabb.analysis import Results
1010
from moabb.datasets.base import BaseDataset
11-
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
11+
from moabb.evaluations.utils import (
12+
_convert_sklearn_params_to_optuna,
13+
check_search_available,
14+
)
1215
from moabb.paradigms.base import BaseParadigm
1316

1417

18+
search_methods, optuna_available = check_search_available()
19+
1520
log = logging.getLogger(__name__)
1621

1722
# Making the optuna soft dependency
18-
try:
19-
from optuna.integration import OptunaSearchCV
20-
21-
optuna_available = True
22-
except ImportError:
23-
optuna_available = False
24-
25-
if optuna_available:
26-
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
27-
else:
28-
search_methods = {"grid": GridSearchCV}
2923

3024

3125
class BaseEvaluation(ABC):
@@ -83,6 +77,8 @@ class BaseEvaluation(ABC):
8377
optuna, time_out parameters.
8478
"""
8579

80+
search = False
81+
8682
def __init__(
8783
self,
8884
paradigm,
@@ -201,7 +197,6 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
201197
This pipeline must be "fixed" because it will not be trained,
202198
i.e. no call to ``fit`` will be made.
203199
204-
205200
Returns
206201
-------
207202
results: pd.DataFrame
@@ -216,26 +211,44 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
216211
if not (isinstance(pipeline, BaseEstimator)):
217212
raise (ValueError("pipelines must only contains Pipelines " "instance"))
218213

219-
res_per_db = []
220-
for dataset in self.datasets:
221-
log.info("Processing dataset: {}".format(dataset.code))
222-
process_pipeline = self.paradigm.make_process_pipelines(
214+
# Prepare dataset processing parameters
215+
processing_params = [
216+
(
223217
dataset,
224-
return_epochs=self.return_epochs,
225-
return_raws=self.return_raws,
226-
postprocess_pipeline=postprocess_pipeline,
227-
)[0]
228-
# (we only keep the pipeline for the first frequency band, better ideas?)
229-
230-
results = self.evaluate(
231-
dataset,
232-
pipelines,
233-
param_grid=param_grid,
234-
process_pipeline=process_pipeline,
235-
postprocess_pipeline=postprocess_pipeline,
218+
self.paradigm.make_process_pipelines(
219+
dataset,
220+
return_epochs=self.return_epochs,
221+
return_raws=self.return_raws,
222+
postprocess_pipeline=postprocess_pipeline,
223+
)[0],
236224
)
225+
for dataset in self.datasets
226+
]
227+
228+
# Parallel processing...
229+
parallel_results = Parallel(n_jobs=self.n_jobs)(
230+
delayed(
231+
lambda d, p: list(
232+
self.evaluate(
233+
d,
234+
pipelines,
235+
param_grid=param_grid,
236+
process_pipeline=p,
237+
postprocess_pipeline=postprocess_pipeline,
238+
)
239+
)
240+
)(dataset, process_pipeline)
241+
for dataset, process_pipeline in processing_params
242+
)
243+
244+
res_per_db = []
245+
# Process results in order
246+
for (dataset, process_pipeline), results in zip(
247+
processing_params, parallel_results
248+
):
237249
for res in results:
238250
self.push_result(res, pipelines, process_pipeline)
251+
239252
res_per_db.append(
240253
self.results.to_dataframe(
241254
pipelines=pipelines, process_pipeline=process_pipeline
@@ -316,9 +329,12 @@ def _grid_search(self, param_grid, name, grid_clf, inner_cv):
316329
return_train_score=True,
317330
**extra_params,
318331
)
332+
self.search = True
319333
return search
320334
else:
335+
self.search = True
321336
return grid_clf
322337

323338
else:
339+
self.search = False
324340
return grid_clf

0 commit comments

Comments
 (0)