Skip to content

Commit 5232ac8

Browse files
author
PaulusBoskabouter
committed
Reworked and reformatted the markdown
1 parent 91051d0 commit 5232ac8

1 file changed

Lines changed: 64 additions & 32 deletions

File tree

examples/how_to_benchmark/tutorial_6_within_session_splitter.py

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
=====================================================
55
"""
66

7+
78
# Authors: Thomas, Kooiman, Radovan Vodila, Jorge Sanmartin Martinez, and Paul Verhoeven
89
#
910
# License: BSD (3-clause)
1011

1112

1213
###############################################################################
13-
# Why would I want to split my data within a session?
14-
###############################################################################
14+
# The justification and goal for within-session splitting
15+
# -----------------------
1516
# In short, because we want to prevent the model from recognizing the subject
1617
# and learning subject-specific representations instead of focusing on the task at hand.
1718
#
@@ -28,20 +29,25 @@
2829
# This approach forms a critical foundation in the MOABB evaluation framework,
2930
# which supports three levels of model generalization:
3031
#
31-
# - Within-session: test generalization across trials within a single session
32-
# - Cross-session: test generalization across different recording sessions
33-
# - Cross-subject: test generalization across different brains
32+
# - Within-session: test generalization across trials within a single session
33+
# - Cross-session: test generalization across different recording sessions
34+
# - Cross-subject: test generalization across different brains
3435
#
3536
# While within-session and cross-session evaluation generalize within the same subject, cross-subject evaluation generalizes between (groups of) subjects.
3637
#
37-
# Each level decreases in specialization, moving from highly subject-specific models
38+
# Each level decreases in specialization, moving from highly subject-specific models,
3839
# to those that can generalize across individuals.
3940
#
4041
# This tutorial focuses on within-session evaluation to establish a reliable
4142
# baseline for model performance before attempting more challenging generalization tasks.
4243

43-
import warnings
4444

45+
###############################################################################
46+
# Importing the necessary libraries
47+
# -----------------------
48+
49+
50+
import warnings
4551
import matplotlib.pyplot as plt
4652

4753
# Standard imports
@@ -65,20 +71,25 @@
6571
warnings.filterwarnings("ignore")
6672
moabb.set_log_level("info")
6773

74+
6875
###############################################################################
69-
# Load the dataset and paradigm
70-
###############################################################################
71-
# We use the BNCI2014_001 dataset: BCI Comp IV dataset 2a (motor imagery)
76+
# Load the dataset
77+
# -----------------------
78+
# In this example we use 3 subjects of the :class:`moabb.datasets.BNCI2014_001` dataset.
79+
80+
7281
dataset = BNCI2014_001()
73-
# Restrict to a few subjects to keep runtime reasonable for demonstration
7482
dataset.subject_list = [1, 2, 3]
7583

76-
# Define the paradigm: here, left vs right hand imagery
77-
paradigm = LeftRightImagery()
7884

7985
###############################################################################
80-
# Extract data: epochs (X), labels (y), and trial metadata (meta)
81-
###############################################################################
86+
# Extract data: epochs (X), labels (y), and trial metadata (meta)
87+
# -----------------------
88+
# For this dataset we use the :class:`moabb.paradigms.LeftRightImagery` paradigm.
89+
# Additionally, we use the `get_data` method to download, preprocess, epoch, and label the data.
90+
91+
92+
paradigm = LeftRightImagery()
8293
# This call downloads (if needed), preprocesses, epochs, and labels the data
8394
X, y, meta = paradigm.get_data(dataset=dataset, subjects=dataset.subject_list)
8495

@@ -88,7 +99,13 @@
8899
print("meta shape (trials, info columns):", meta.shape)
89100
print(meta.head()) # shows subject/session for each trial
90101

91-
# Plot a small epoch (e.g., the first trial) (3 channels for simplicity sake)
102+
103+
###############################################################################
104+
# Visualising a single epoch
105+
# -----------------------
106+
# Plot a single epoch (e.g., the first trial) to see what's in this dataset (limiting to 3 channels for simplicity's sake).
107+
108+
92109
plt.figure(figsize=(10, 4))
93110
plt.plot(X[0][0:3].T) # Transpose to plot channels over time
94111
plt.title("Epoch 0: EEG Channels Over Time")
@@ -98,21 +115,27 @@
98115
plt.tight_layout()
99116
plt.show()
100117

118+
101119
###############################################################################
102-
# Build a classification pipeline: CSP to LDA
103-
###############################################################################
104-
# Common Spatial Patterns (CSP) finds spatial filters that maximize variance difference between classes
105-
# Linear Discriminant Analysis (LDA) is a simple linear classifier on the CSP features
120+
# Build a classification pipeline: CSP to LDA
121+
# -----------------------
122+
# We use Common Spatial Patterns (CSP), which finds spatial filters that maximize the variance difference between classes.
123+
# We then use Linear Discriminant Analysis (LDA) as a simple linear classifier on the extracted CSP features.
124+
125+
106126
pipe = make_pipeline(
107127
CSP(n_components=6, reg=None), # reduce to 6 CSP components
108128
LDA(), # classify based on these features
109129
)
110130
pipe
111131

132+
112133
###############################################################################
113-
# Instantiate WithinSessionSplitter
114-
###############################################################################
115-
# We want 5-fold CV _within_ each subject × session grouping
134+
# Instantiate WithinSessionSplitter
135+
# -----------------------
136+
# We want 5-fold cross-validation (CV) within each subject × session grouping
137+
138+
116139
wss = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=404)
117140
print(f"Splitter config: folds={wss.n_folds}, shuffle={wss.shuffle}")
118141

@@ -124,10 +147,13 @@
124147
if wss.get_n_splits(meta) == 0:
125148
raise RuntimeError("No splits generated: check that each subject has ≥2 sessions.")
126149

150+
127151
###############################################################################
128152
# Manual evaluation loop: train/test each fold
129-
###############################################################################
153+
# -----------------------
130154
# We'll collect one row per fold: which subject/session was held out and its score
155+
156+
131157
records = []
132158
for fold_id, (train_idx, test_idx) in enumerate(wss.split(y, meta)):
133159
# Slice our epoch array and labels
@@ -166,18 +192,22 @@
166192

167193
###############################################################################
168194
# Summary of results
169-
###############################################################################
195+
# -----------------------
170196
# We can quickly see per-subject, per-session performance:
171-
summary = df.groupby(["subject", "session"])["score"].agg(["mean", "std"]).reset_index()
172-
print("\nSummary of within-session fold scores (mean ± std):")
173-
print(summary)
174197
# We see subject 2’s Session 1 has lower mean accuracy, suggesting session variability.
175198
# Note: you could plot these numbers to visually compare sessions,
176199
# but here we print them to focus on the splitting logic itself.
177200

201+
202+
summary = df.groupby(["subject", "session"])["score"].agg(["mean", "std"]).reset_index()
203+
print("\nSummary of within-session fold scores (mean ± std):")
204+
print(summary)
205+
206+
207+
178208
##########################################################################
179-
# Plot results
180-
##########################################################################
209+
# Visualisation of the results
210+
# -----------------------
181211

182212

183213
df["subject"] = df["subject"].astype(str)
@@ -192,7 +222,10 @@
192222

193223
###############################################################################
194224
# Visualisation of the data split
195-
###############################################################################
225+
# -----------------------
226+
# For our 3 subjects, we see that each subject has 5 folds of training data.
227+
228+
196229
def plot_subject_split(ax, df):
197230
"""Create a bar plot showing the split of subject data into train and test."""
198231
colors = ["#3A6190", "#DDF2FF"] # Colors for train and test
@@ -224,4 +257,3 @@ def plot_subject_split(ax, df):
224257
# Add the subject split plot to the figure
225258
plot_subject_split(ax, df)
226259

227-
# For our 3 subjects, we see that each subject has 5 folds of training data.

0 commit comments

Comments
 (0)