=====================================================
"""


# Authors: Thomas Kooiman, Radovan Vodila, Jorge Sanmartin Martinez, and Paul Verhoeven
#
# License: BSD (3-clause)


###############################################################################
# The justification and goal for within-session splitting
# ---------------------------------------------------------
# In short, we want to prevent the model from recognizing the subject
# and learning subject-specific representations instead of focusing on the task at hand.
#
# This approach forms a critical foundation in the MOABB evaluation framework,
# which supports three levels of model generalization:
#
# - Within-session: test generalization across trials within a single session
# - Cross-session: test generalization across different recording sessions
# - Cross-subject: test generalization across different brains
#
# Within-session and cross-session evaluation generalize within the same subject,
# whereas cross-subject evaluation generalizes between (groups of) subjects.
#
# Each level decreases in specialization, moving from highly subject-specific models
# to those that can generalize across individuals.
#
# This tutorial focuses on within-session evaluation to establish a reliable
# baseline for model performance before attempting more challenging generalization tasks.

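###############################################################################
# For orientation only: MOABB also ships ready-made evaluation classes for these
# three levels in ``moabb.evaluations``. The import below is a pointer to that
# higher-level API and is not used further in this tutorial, which instead builds
# the within-session procedure by hand with a splitter.


from moabb.evaluations import (
    CrossSessionEvaluation,
    CrossSubjectEvaluation,
    WithinSessionEvaluation,
)

print([cls.__name__ for cls in (WithinSessionEvaluation, CrossSessionEvaluation, CrossSubjectEvaluation)])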

###############################################################################
# Importing the necessary libraries
# -------------------------------------


import warnings
import matplotlib.pyplot as plt

# Standard imports
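# The remaining imports are elided in this excerpt. Based on the names used
# later in the tutorial, the block plausibly looks like the following (module
# paths are assumptions from the standard MOABB, MNE, and scikit-learn APIs;
# in particular, the splitter's import path may differ between MOABB versions):
import moabb
import pandas as pd
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb.datasets import BNCI2014_001
from moabb.evaluations.splitters import WithinSessionSplitter
from moabb.paradigms import LeftRightImagery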
warnings.filterwarnings("ignore")
moabb.set_log_level("info")

###############################################################################
# Load the dataset
# --------------------
# In this example we use 3 subjects from the :class:`moabb.datasets.BNCI2014_001`
# dataset (BCI Competition IV dataset 2a, motor imagery).


dataset = BNCI2014_001()
dataset.subject_list = [1, 2, 3]  # restrict to a few subjects to keep runtime reasonable


###############################################################################
# Extract data: epochs (X), labels (y), and trial metadata (meta)
# --------------------------------------------------------------------
# For this dataset we use the :class:`moabb.paradigms.LeftRightImagery` paradigm.
# We then call its ``get_data`` method to download, preprocess, epoch, and label the data.


paradigm = LeftRightImagery()
# This call downloads (if needed), preprocesses, epochs, and labels the data
X, y, meta = paradigm.get_data(dataset=dataset, subjects=dataset.subject_list)

print("meta shape (trials, info columns):", meta.shape)
print(meta.head())  # shows subject/session for each trial

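###############################################################################
# ``meta`` stores the subject and session of every trial. Counting trials per
# (subject, session) pair shows the groups inside which the within-session
# splitter will later build its folds.


print(meta.groupby(["subject", "session"]).size())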

###############################################################################
# Visualising a single epoch
# ------------------------------
# Plot a single epoch (e.g., the first trial) to see what is in this dataset,
# limited to 3 channels for simplicity.


plt.figure(figsize=(10, 4))
plt.plot(X[0][0:3].T)  # Transpose to plot channels over time
plt.title("Epoch 0: EEG Channels Over Time")
plt.tight_layout()
plt.show()

###############################################################################
# Build a classification pipeline: CSP to LDA
# ------------------------------------------------
# We use Common Spatial Patterns (CSP) to find spatial filters that maximize the
# variance difference between classes, and then Linear Discriminant Analysis (LDA)
# as a simple linear classifier on the extracted CSP features.


pipe = make_pipeline(
    CSP(n_components=6, reg=None),  # reduce to 6 CSP components
    LDA(),  # classify based on these features
)
pipe


###############################################################################
# Instantiate WithinSessionSplitter
# -------------------------------------
# We want 5-fold cross-validation (CV) within each subject × session grouping.


wss = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=404)
print(f"Splitter config: folds={wss.n_folds}, shuffle={wss.shuffle}")

if wss.get_n_splits(meta) == 0:
    raise RuntimeError("No splits generated: check that each session has enough trials for the requested folds.")

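###############################################################################
# As an illustrative sanity check (a small sketch, not required for the rest of
# the tutorial), peek at the first fold: because folds are built within a
# session, the held-out trials should all come from a single subject and a
# single session.


first_train_idx, first_test_idx = next(iter(wss.split(y, meta)))
print(meta.iloc[first_test_idx][["subject", "session"]].drop_duplicates())
print("train size:", len(first_train_idx), "| test size:", len(first_test_idx))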

###############################################################################
# Manual evaluation loop: train/test each fold
# --------------------------------------------------
# We'll collect one row per fold: which subject/session was held out and its score.


records = []
for fold_id, (train_idx, test_idx) in enumerate(wss.split(y, meta)):
    # Slice our epoch array and labels
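    # (The remainder of the loop body is elided in this excerpt; below is a
    # minimal sketch of what it plausibly does, assuming accuracy scoring via
    # ``pipe.score`` and the ``df`` columns used later in this example.)
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]

    # Fit on the training fold and evaluate on the held-out fold
    pipe.fit(X_train, y_train)
    info = meta.iloc[test_idx].iloc[0]  # within-session: one subject/session per fold
    records.append(
        {
            "fold": fold_id,
            "subject": info["subject"],
            "session": info["session"],
            "score": pipe.score(X_test, y_test),
        }
    )

# Collect the per-fold results into a DataFrame with the columns used below
df = pd.DataFrame(records)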

###############################################################################
# Summary of results
# ----------------------
# We can quickly see per-subject, per-session performance.
# We see that subject 2's session 1 has a lower mean accuracy, suggesting session variability.
# Note: you could plot these numbers to visually compare sessions,
# but here we print them to focus on the splitting logic itself.


summary = df.groupby(["subject", "session"])["score"].agg(["mean", "std"]).reset_index()
print("\nSummary of within-session fold scores (mean ± std):")
print(summary)


##########################################################################
# Visualisation of the results
# --------------------------------


df["subject"] = df["subject"].astype(str)
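
###############################################################################
# The plotting code itself is elided in this excerpt. One plausible way to
# visualise the per-fold scores (a sketch, assuming seaborn is available) is a
# grouped bar plot of score per subject and session.


import seaborn as sns

fig, ax = plt.subplots(figsize=(8, 4))
sns.barplot(data=df, x="subject", y="score", hue="session", ax=ax)
ax.set_ylabel("Within-session accuracy")
ax.set_title("Within-session fold scores per subject and session")
plt.tight_layout()
plt.show()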

###############################################################################
# Visualisation of the data split
# -----------------------------------
# For our 3 subjects, we see that each subject has 5 folds of training data.


def plot_subject_split(ax, df):
    """Create a bar plot showing the split of subject data into train and test."""
    colors = ["#3A6190", "#DDF2FF"]  # Colors for train and test
    # ... (the rest of the plotting helper and the figure setup are elided in this excerpt)
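
# (Sketch) The figure and axes are created in the elided code above; presumably
# something along the lines of:
fig, ax = plt.subplots()
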
# Add the subject split plot to the figure
plot_subject_split(ax, df)