Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Enhancements
- Adding :class:`moabb.datasets.Beetl2021A` and :class:`moabb.datasets.Beetl2021B` (:gh:`675` by `Samuel Boehm`_)
- Adding :class:`moabb.evaluations.splitters.CrossSessionSplitter` (:gh:`720` by `Bruna Lopes`_ and `Bruno Aristimunha`_)
- Adding :class:`moabb.datasets.base.BaseBIDSDataset` and :class:`moabb.datasets.base.LocalBIDSDataset` (:gh:`724` by `Pierre Guetschel`_)

- Adding :class:`moabb.evaluations.CrossDatasetEvaluation` and :class:`moabb.evaluations.splitters.CrossDatasetSplitter` for cross-dataset evaluation, enabling training on one dataset and testing on another (:gh:`703` by `Ali Imran`_)

Bugs
~~~~
Expand Down Expand Up @@ -547,3 +547,4 @@ API changes
.. _AFF: https://github.com/allwaysFindFood
.. _Marco Congedo: https://github.com/Marco-Congedo
.. _Samuel Boehm: https://github.com/Samuel-Boehm
.. _Ali Imran: https://github.com/EazyAl
48 changes: 48 additions & 0 deletions examples/advanced_examples/plot_cross_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Cross-dataset motor imagery classification
===========================================

This example shows how to train on one dataset (BNCI2014_001) and
test on another (Zhou2016) using ``CrossDatasetEvaluation``.
Channel alignment and resampling are handled automatically.
"""

import matplotlib.pyplot as plt
from pyriemann.estimation import Covariances
from pyriemann.spatialfilters import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb import set_log_level
from moabb.datasets import BNCI2014001, Zhou2016
from moabb.evaluations import CrossDatasetEvaluation
from moabb.paradigms import LeftRightImagery

set_log_level("WARNING")

paradigm = LeftRightImagery()

train_dataset = BNCI2014001()
test_dataset = Zhou2016()

pipelines = {
"CSP+LDA": make_pipeline(Covariances("oas"), CSP(nfilter=6), LDA()),
}

evaluation = CrossDatasetEvaluation(
paradigm=paradigm,
train_datasets=train_dataset,
test_datasets=test_dataset,
)

results = evaluation.process(pipelines)

print(results[["dataset", "subject", "session", "score"]])

fig, ax = plt.subplots(figsize=(8, 5))
results.boxplot(column="score", by="pipeline", ax=ax)
ax.set_title("Cross-dataset: BNCI2014_001 -> Zhou2016")
ax.set_ylabel("Score")
plt.suptitle("")
plt.tight_layout()
plt.show()
3 changes: 2 additions & 1 deletion moabb/evaluations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

# flake8: noqa
from .evaluations import (
CrossDatasetEvaluation,
CrossSessionEvaluation,
CrossSubjectEvaluation,
WithinSessionEvaluation,
)
from .splitters import CrossSessionSplitter, WithinSessionSplitter
from .splitters import CrossDatasetSplitter, CrossSessionSplitter, WithinSessionSplitter
from .utils import create_save_path, save_model_cv, save_model_list
208 changes: 207 additions & 1 deletion moabb/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from typing import Optional, Union

import numpy as np
import pandas as pd
from mne.epochs import BaseEpochs
from sklearn.base import clone
from sklearn.base import BaseEstimator, clone
from sklearn.metrics import get_scorer
from sklearn.model_selection import (
GroupKFold,
Expand Down Expand Up @@ -780,3 +781,208 @@ def evaluate(

def is_valid(self, dataset):
    """Return True when *dataset* can be used by this evaluation.

    NOTE(review): the enclosing class header is outside this chunk; the
    check requires at least two subjects, consistent with a cross-subject
    scheme — confirm against the full file.
    """
    return len(dataset.subject_list) > 1


class CrossDatasetEvaluation(BaseEvaluation):
    """Cross-dataset evaluation: train on one or more datasets, test on others.

    This evaluation trains pipelines on all data from the training datasets
    and evaluates per-subject on the test datasets. Channels and sampling
    rates are automatically aligned across datasets using
    ``paradigm.match_all``.

    Parameters
    ----------
    train_datasets : Dataset or list of Dataset
        Dataset(s) to use for training.
    test_datasets : Dataset or list of Dataset
        Dataset(s) to use for testing.
    **kwargs : dict
        Additional parameters passed to BaseEvaluation (paradigm, n_jobs, etc.)

    Notes
    -----
    .. versionadded:: 1.2.1
    """

    def __init__(self, train_datasets, test_datasets, **kwargs):
        # Accept a single dataset or a list for both arguments.
        if not isinstance(train_datasets, list):
            train_datasets = [train_datasets]
        if not isinstance(test_datasets, list):
            test_datasets = [test_datasets]

        self.train_datasets = train_datasets
        self.test_datasets = test_datasets

        if not self.train_datasets:
            raise ValueError("train_datasets must not be empty")
        if not self.test_datasets:
            raise ValueError("test_datasets must not be empty")

        # A dataset may not appear on both sides: that would leak test
        # data into training.
        train_codes = {ds.code for ds in self.train_datasets}
        test_codes = {ds.code for ds in self.test_datasets}
        overlap = train_codes & test_codes
        if overlap:
            raise ValueError(f"Datasets cannot be both train and test: {overlap}")

        # Pass all datasets to super for paradigm validation.
        all_datasets = self.train_datasets + self.test_datasets
        super().__init__(datasets=all_datasets, **kwargs)

        # Align channel sets (intersection) and sampling rates so data from
        # different datasets can be concatenated into one container.
        self.paradigm.match_all(self.datasets, channel_merge_strategy="intersect")

    def _load_data(self, postprocess_pipeline):
        """Load, tag and concatenate the data of every dataset.

        Returns ``(X, y, metadata, ds_code_to_obj)`` where ``metadata`` has a
        ``dataset`` column holding each sample's dataset code and
        ``ds_code_to_obj`` maps codes back to dataset objects.
        """
        all_X, all_y, all_metadata = [], [], []
        ds_code_to_obj = {}

        for ds in self.datasets:
            X, y, metadata = self.paradigm.get_data(
                dataset=ds,
                return_epochs=self.return_epochs,
                return_raws=self.return_raws,
                cache_config=self.cache_config,
                postprocess_pipeline=postprocess_pipeline,
            )
            metadata = metadata.copy()
            metadata["dataset"] = ds.code
            ds_code_to_obj[ds.code] = ds

            all_X.append(X)
            all_y.append(y)
            all_metadata.append(metadata)

        if self.return_raws:
            # Per-trial concatenation of Raw objects is not meaningful here.
            raise NotImplementedError(
                "CrossDatasetEvaluation does not support return_raws=True"
            )
        if self.return_epochs:
            # Epochs objects cannot go through np.concatenate; use MNE's
            # concatenation (channels/sfreq were aligned by match_all).
            from mne import concatenate_epochs

            X = concatenate_epochs(all_X)
        else:
            X = np.concatenate(all_X, axis=0)
        y = np.concatenate(all_y, axis=0)
        metadata = pd.concat(all_metadata, ignore_index=True)
        return X, y, metadata, ds_code_to_obj

    def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
        """Run cross-dataset evaluation across all pipelines.

        Loads data from all train and test datasets, concatenates them,
        and uses :class:`CrossDatasetSplitter` to generate per-subject
        train/test splits.

        Parameters
        ----------
        pipelines : dict of pipeline instance.
            A dict containing the sklearn pipeline to evaluate.
        param_grid : dict of str, default=None
            The key of the dictionary must be the same as the associated
            pipeline. Currently unused by this evaluation.
        postprocess_pipeline : Pipeline | None, default=None
            Optional pipeline to apply to the data after preprocessing.

        Returns
        -------
        results : pd.DataFrame
            A dataframe containing the results.
        """
        from moabb.evaluations.splitters import CrossDatasetSplitter

        if not isinstance(pipelines, dict):
            raise ValueError("pipelines must be a dict")

        for pipeline in pipelines.values():
            if not isinstance(pipeline, BaseEstimator):
                raise ValueError("pipelines must only contain Pipeline instances")

        # Build a process pipeline from the first dataset (all are now matched).
        process_pipeline = self.paradigm.make_process_pipelines(
            self.datasets[0],
            return_epochs=self.return_epochs,
            return_raws=self.return_raws,
            postprocess_pipeline=postprocess_pipeline,
        )[0]

        X, y, metadata, ds_code_to_obj = self._load_data(postprocess_pipeline)

        le = LabelEncoder()
        y_encoded = y if self.mne_labels else le.fit_transform(y)

        splitter = CrossDatasetSplitter(
            train_datasets=[ds.code for ds in self.train_datasets],
            test_datasets=[ds.code for ds in self.test_datasets],
        )

        scorer = get_scorer(self.paradigm.scoring)

        # Invariant across splits: compute once.
        nchan = X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1]

        # The splitter yields the same train set for every test subject, so
        # fit each pipeline once per distinct train index set and reuse the
        # model instead of refitting per subject.
        fitted_cache = {}

        for train_idx, test_idx in splitter.split(y_encoded, metadata):
            X_train, y_train = X[train_idx], y_encoded[train_idx]
            X_test, y_test = X[test_idx], y_encoded[test_idx]

            test_metadata = metadata.iloc[test_idx]
            test_ds_code = test_metadata["dataset"].iloc[0]
            test_subject = test_metadata["subject"].iloc[0]
            test_dataset_obj = ds_code_to_obj[test_ds_code]

            # Score per session within this subject.
            sessions = test_metadata["session"].unique()

            for name, clf in pipelines.items():
                cache_key = (name, train_idx.tobytes())
                if cache_key not in fitted_cache:
                    if _carbonfootprint:
                        tracker = EmissionsTracker(
                            save_to_file=False, log_level="error"
                        )
                        tracker.start()

                    t_start = time()
                    model = clone(clf).fit(X_train, y_train)
                    duration = time() - t_start

                    emissions = np.nan
                    if _carbonfootprint:
                        emissions = tracker.stop()
                        if emissions is None:
                            emissions = np.nan

                    fitted_cache[cache_key] = (model, duration, emissions)

                model, duration, emissions = fitted_cache[cache_key]

                for session in sessions:
                    # Positional indices within the test split, aligned with
                    # X_test/y_test ordering.
                    sess_idx = np.where(
                        (test_metadata["session"] == session).to_numpy()
                    )[0]
                    score = scorer(model, X_test[sess_idx], y_test[sess_idx])

                    res = {
                        "time": duration,
                        "dataset": test_dataset_obj,
                        "subject": test_subject,
                        "session": session,
                        "score": score,
                        "n_samples": len(y_train),
                        "n_channels": nchan,
                        "pipeline": name,
                    }
                    if _carbonfootprint:
                        res["carbon_emission"] = 1000 * emissions

                    self.push_result(res, pipelines, process_pipeline)

        return self.results.to_dataframe(
            pipelines=pipelines, process_pipeline=process_pipeline
        )

    def evaluate(
        self, dataset, pipelines, param_grid, process_pipeline, postprocess_pipeline=None
    ):
        """Not used directly — use :meth:`process` instead.

        This method satisfies the abstract interface but is not called
        because :meth:`process` is overridden.
        """
        raise NotImplementedError(
            "CrossDatasetEvaluation.evaluate() should not be called directly. "
            "Use process() instead."
        )

    def is_valid(self, dataset):
        """Check if dataset is valid for this evaluation.

        Always returns True because multi-dataset validation is
        enforced in ``__init__``, not per-dataset.
        """
        return True
79 changes: 79 additions & 0 deletions moabb/evaluations/splitters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import logging

import numpy as np
from sklearn.model_selection import (
BaseCrossValidator,
LeaveOneGroupOut,
Expand Down Expand Up @@ -267,3 +268,81 @@ def split(self, y, metadata):
yield subject_indices[train_session_idx], subject_indices[
test_session_idx
]


class CrossDatasetSplitter(BaseCrossValidator):
    """Data splitter for cross-dataset evaluation.

    This splitter enables cross-dataset evaluation by splitting data based on
    dataset membership. All samples from training datasets are used as training
    data, and test splits are created per-subject within each test dataset.

    Parameters
    ----------
    train_datasets : list of str
        List of dataset codes to use for training.
    test_datasets : list of str
        List of dataset codes to use for testing.

    Yields
    ------
    train : ndarray
        The training set indices for that split (all samples from train datasets).
    test : ndarray
        The testing set indices for that split (one subject from one test dataset).
    """

    def __init__(self, train_datasets, test_datasets):
        self.train_datasets = train_datasets
        self.test_datasets = test_datasets

    def get_n_splits(self, y=None, metadata=None):
        """Return the number of splits.

        The number of splits equals the number of unique (dataset, subject)
        pairs in the test datasets.

        Parameters
        ----------
        y : array-like, default=None
            Ignored, present for API compatibility.
        metadata : pd.DataFrame
            Must contain 'dataset' and 'subject' columns.

        Returns
        -------
        n_splits : int
            The number of splits.

        Raises
        ------
        ValueError
            If ``metadata`` is not provided.
        """
        # Fail fast with a clear message instead of an opaque TypeError
        # when the caller forgets the keyword-only-in-spirit metadata.
        if metadata is None:
            raise ValueError(
                "metadata is required to compute the number of splits"
            )
        n_splits = 0
        for test_code in self.test_datasets:
            ds_mask = metadata["dataset"] == test_code
            n_splits += metadata.loc[ds_mask, "subject"].nunique()
        return int(n_splits)

    def split(self, y, metadata):
        """Generate train/test indices for cross-dataset evaluation.

        Parameters
        ----------
        y : array-like
            Target variable (unused, present for API compatibility).
        metadata : pd.DataFrame
            Must contain 'dataset' and 'subject' columns.

        Yields
        ------
        train_indices : ndarray
            Indices of training samples (all samples from train datasets).
        test_indices : ndarray
            Indices of test samples (one subject from one test dataset).

        Raises
        ------
        ValueError
            If no sample belongs to any of the training datasets, which
            would otherwise produce empty train sets and an obscure
            downstream fitting error.
        """
        train_mask = metadata["dataset"].isin(self.train_datasets)
        train_indices = np.where(train_mask.to_numpy())[0]
        if train_indices.size == 0:
            raise ValueError(
                f"No samples found for train datasets {self.train_datasets}; "
                "check the dataset codes present in metadata['dataset']."
            )

        for test_code in self.test_datasets:
            ds_mask = metadata["dataset"] == test_code
            for subject in metadata.loc[ds_mask, "subject"].unique():
                subj_mask = ds_mask & (metadata["subject"] == subject)
                test_indices = np.where(subj_mask.to_numpy())[0]
                yield train_indices, test_indices
Loading