From 3d246894518e7d2c6e4a56eee3a33a9bf4aca4f1 Mon Sep 17 00:00:00 2001 From: msaunders804 Date: Sun, 12 Apr 2026 13:39:31 -0500 Subject: [PATCH 1/6] add WESAD dataset class and tests --- pyhealth/datasets/wesad.py | 236 ++++++++++++++++++++++++++++++++++++ tests/test_wesad_dataset.py | 103 ++++++++++++++++ 2 files changed, 339 insertions(+) create mode 100644 pyhealth/datasets/wesad.py create mode 100644 tests/test_wesad_dataset.py diff --git a/pyhealth/datasets/wesad.py b/pyhealth/datasets/wesad.py new file mode 100644 index 000000000..674df6b53 --- /dev/null +++ b/pyhealth/datasets/wesad.py @@ -0,0 +1,236 @@ +""" +PyHealth dataset for the WESAD (Wearable Stress and Affect Detection) dataset. + +Dataset link: + https://archive.ics.uci.edu/dataset/465/wesad+wearable+stress+and+affect+detection + +Dataset paper: (please cite if you use this dataset) + Schmidt, P., Reiss, A., Duerichen, R., Marberger, C., & Van Laerhoven, K. + "Introducing WESAD, a Multimodal Dataset for Wearable Stress and Affect + Detection." Proceedings of the 20th ACM International Conference on + Multimodal Interaction, 2018, pp. 400-408. 
+ +Dataset paper link: + https://dl.acm.org/doi/10.1145/3242969.3242985 + +Authors: + Megan Saunders, Jennifer Miranda, Jesus Torres + {meganas4, jm123, jesusst2}@illinois.edu +""" + +import logging +import os +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + +# Affective state labels as defined in the WESAD protocol +WESAD_LABELS = { + 0: "not_defined", + 1: "baseline", + 2: "stress", + 3: "amusement", + 4: "meditation", +} + +# Subjects in the dataset (S1 was discarded by original authors) +WESAD_SUBJECTS = [ + "S2", "S3", "S4", "S5", "S6", + "S7", "S8", "S9", "S10", "S11", + "S13", "S14", "S15", "S16", "S17", +] + +# EDA sampling rate from wrist-worn Empatica E4 device (Hz) +EDA_SAMPLE_RATE = 4 + + +class WESADDataset(BaseDataset): + """Dataset class for the WESAD wearable stress and affect detection dataset. + + WESAD contains physiological and motion data from 15 subjects recorded + during a lab study with three affective states: baseline, stress, and + amusement. This class extracts the electrodermal activity (EDA) signal + from the wrist-worn Empatica E4 device and windows it into fixed-length + segments for downstream classification. + + Attributes: + root (str): Root directory of the raw data. + dataset_name (str): Name of the dataset. + config_path (str): Path to the configuration file. + window_size (int): Number of samples per window. + step_size (int): Step size between consecutive windows in samples. + label_map (Dict[int, int]): Mapping from WESAD label codes to + task-specific integer labels. + subjects (List[str]): List of subject IDs included in the dataset. 
+ """ + + def __init__( + self, + root: str, + window_size: int = 60, + step_size: int = 10, + label_map: Optional[Dict[int, int]] = None, + subjects: Optional[List[str]] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + """Initializes the WESAD dataset. + + Args: + root (str): Root directory containing per-subject subdirectories, + each with a '.pkl' file. + window_size (int): Number of EDA samples per window. At 4 Hz, + 60 samples = 15 seconds. Defaults to 60. + step_size (int): Number of samples to advance between windows. + Defaults to 10. + label_map (Optional[Dict[int, int]]): Mapping from raw WESAD label + codes to output class integers. Raw codes are: 1=baseline, + 2=stress, 3=amusement. Defaults to binary stress detection: + {1: 0, 2: 1} (baseline=0, stress=1, amusement excluded). + subjects (Optional[List[str]]): List of subject IDs to load. + Defaults to all 15 subjects. + config_path (Optional[str]): Path to PyHealth config YAML. + + Raises: + FileNotFoundError: If root directory does not exist. + FileNotFoundError: If no subject pickle files are found in root. + + Example:: + >>> dataset = WESADDataset(root="./WESAD") + >>> print(len(dataset.samples)) + """ + self.window_size = window_size + self.step_size = step_size + self.label_map = label_map or {1: 0, 2: 1} # binary: baseline vs stress + self.subjects = subjects or WESAD_SUBJECTS + + self._verify_data(root) + self.samples = self._load_and_window(root) + + super().__init__( + root=root, + tables=["wesad"], + dataset_name="WESAD", + config_path=config_path, + **kwargs, + ) + + def _verify_data(self, root: str) -> None: + """Verifies the dataset directory structure. + + Args: + root (str): Root directory of the raw data. + + Raises: + FileNotFoundError: If root does not exist. + FileNotFoundError: If no subject pickle files are found. 
+ """ + if not os.path.exists(root): + msg = f"Dataset root does not exist: {root}" + logger.error(msg) + raise FileNotFoundError(msg) + + pkl_files = list(Path(root).rglob("*.pkl")) + if not pkl_files: + msg = f"No .pkl files found under {root}. Ensure WESAD is downloaded and extracted." + logger.error(msg) + raise FileNotFoundError(msg) + + logger.info(f"Found {len(pkl_files)} subject pickle files.") + + def _load_subject( + self, root: str, subject_id: str + ) -> Tuple[np.ndarray, np.ndarray]: + """Loads EDA signal and labels for a single subject. + + Args: + root (str): Root directory of the raw data. + subject_id (str): Subject identifier (e.g. 'S2'). + + Returns: + Tuple[np.ndarray, np.ndarray]: EDA signal array of shape (N,) + and label array of shape (N,) at EDA sampling rate. + + Raises: + FileNotFoundError: If the subject pickle file does not exist. + """ + pkl_path = os.path.join(root, subject_id, f"{subject_id}.pkl") + if not os.path.exists(pkl_path): + msg = f"Subject file not found: {pkl_path}" + logger.error(msg) + raise FileNotFoundError(msg) + + with open(pkl_path, "rb") as f: + data = pickle.load(f, encoding="latin1") + + # EDA from wrist device (Empatica E4), shape (N,) + eda = data["signal"]["wrist"]["EDA"].flatten() + + # Labels are at chest device rate (700 Hz), downsample to EDA rate (4 Hz) + labels_chest = data["label"].flatten() + downsample_factor = len(labels_chest) // len(eda) + labels = labels_chest[::downsample_factor][: len(eda)] + + logger.info(f"Loaded subject {subject_id}: {len(eda)} EDA samples.") + return eda, labels + + def _load_and_window(self, root: str) -> List[Dict]: + """Loads all subjects and segments EDA into labeled windows. + + Windows with labels not present in self.label_map are discarded. + A window's label is assigned by majority vote over its samples. + + Args: + root (str): Root directory of the raw data. + + Returns: + List[Dict]: List of sample dicts with keys: + - 'subject_id' (str): Subject identifier. 
+ - 'eda' (np.ndarray): EDA window of shape (window_size,). + - 'label' (int): Integer class label. + """ + samples = [] + + for subject_id in self.subjects: + try: + eda, labels = self._load_subject(root, subject_id) + except FileNotFoundError: + logger.warning(f"Skipping missing subject: {subject_id}") + continue + + # Slide window over signal + for start in range(0, len(eda) - self.window_size + 1, self.step_size): + end = start + self.window_size + window_eda = eda[start:end] + window_labels = labels[start:end] + + # Majority vote for window label + values, counts = np.unique(window_labels, return_counts=True) + majority_label = int(values[np.argmax(counts)]) + + # Skip windows with labels not in the label map + if majority_label not in self.label_map: + continue + + samples.append( + { + "subject_id": subject_id, + "eda": window_eda.astype(np.float32), + "label": self.label_map[majority_label], + } + ) + + logger.info(f"Total windows: {len(samples)}") + return samples + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int) -> Dict: + return self.samples[idx] \ No newline at end of file diff --git a/tests/test_wesad_dataset.py b/tests/test_wesad_dataset.py new file mode 100644 index 000000000..fcbb8cbd7 --- /dev/null +++ b/tests/test_wesad_dataset.py @@ -0,0 +1,103 @@ +"""Tests for the WESAD dataset class using synthetic data.""" + +import os +import pickle +import tempfile +import unittest + +import numpy as np + +from pyhealth.datasets.wesad import WESADDataset, EDA_SAMPLE_RATE + + +def _make_synthetic_subject(subject_id: str, root: str, n_seconds: int = 300) -> None: + """Creates a synthetic WESAD pickle file for testing.""" + n_eda = n_seconds * EDA_SAMPLE_RATE # 4 Hz + n_chest = n_eda * 175 # ~700 Hz chest device + + eda = np.random.rand(n_eda, 1).astype(np.float32) + + # Labels: first third baseline (1), second third stress (2), rest amusement (3) + labels = np.ones(n_chest, dtype=int) + labels[n_chest // 3: 2 * 
n_chest // 3] = 2 + labels[2 * n_chest // 3:] = 3 + + data = { + "signal": {"wrist": {"EDA": eda}}, + "label": labels, + } + + subject_dir = os.path.join(root, subject_id) + os.makedirs(subject_dir, exist_ok=True) + with open(os.path.join(subject_dir, f"{subject_id}.pkl"), "wb") as f: + pickle.dump(data, f) + + +class TestWESADDataset(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.subjects = ["S2", "S3"] + for sid in self.subjects: + _make_synthetic_subject(sid, self.tmp_dir) + + def test_loads_without_error(self): + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + window_size=60, + step_size=10, + ) + self.assertGreater(len(dataset), 0) + + def test_sample_shape(self): + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + window_size=60, + step_size=10, + ) + sample = dataset[0] + self.assertEqual(sample["eda"].shape, (60,)) + + def test_binary_labels_only(self): + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + label_map={1: 0, 2: 1}, + ) + labels = {s["label"] for s in dataset.samples} + self.assertTrue(labels.issubset({0, 1})) + + def test_three_class_label_map(self): + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + label_map={1: 0, 2: 1, 3: 2}, + ) + labels = {s["label"] for s in dataset.samples} + self.assertTrue(labels.issubset({0, 1, 2})) + + def test_missing_root_raises(self): + with self.assertRaises(FileNotFoundError): + WESADDataset(root="/nonexistent/path", subjects=self.subjects) + + def test_subject_id_in_sample(self): + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + ) + self.assertIn(dataset[0]["subject_id"], self.subjects) + + def test_window_size_respected(self): + for window_size in [30, 60, 120]: + dataset = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + window_size=window_size, + ) + self.assertEqual(dataset[0]["eda"].shape[0], window_size) + + +if __name__ == "__main__": 
+ unittest.main() \ No newline at end of file From 07af41b29883764cec4b016811c15920217869f2 Mon Sep 17 00:00:00 2001 From: msaunders804 Date: Sun, 12 Apr 2026 13:40:37 -0500 Subject: [PATCH 2/6] add stress detection task and tests --- pyhealth/tasks/stress_detection.py | 104 +++++++++++++++++++++++++++++ tests/test_stress_detection.py | 81 ++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 pyhealth/tasks/stress_detection.py create mode 100644 tests/test_stress_detection.py diff --git a/pyhealth/tasks/stress_detection.py b/pyhealth/tasks/stress_detection.py new file mode 100644 index 000000000..694e19846 --- /dev/null +++ b/pyhealth/tasks/stress_detection.py @@ -0,0 +1,104 @@ +""" +Stress detection task for the WESAD dataset. + +Maps windowed EDA samples from WESADDataset to model-ready +input/output pairs for binary or three-class stress classification. + +Authors: + Megan Saunders, Jennifer Miranda, Jesus Torres + {meganas4, jm123, jesusst2}@illinois.edu +""" + +import logging +from typing import Dict, List, Optional + +import numpy as np +import torch +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +class StressDetectionDataset(Dataset): + """PyTorch Dataset wrapping WESAD windows for stress detection. + + Takes the list of sample dicts produced by WESADDataset and returns + tensors suitable for PyHealth model training. + + Attributes: + samples (List[Dict]): Windowed EDA samples from WESADDataset. + subject_ids (List[str]): Unique subject identifiers present in samples. + num_classes (int): Number of output classes inferred from label set. 
+ + Example:: + >>> from pyhealth.datasets.wesad import WESADDataset + >>> from pyhealth.tasks.stress_detection import StressDetectionDataset + >>> raw = WESADDataset(root="./WESAD") + >>> task = StressDetectionDataset(raw.samples) + >>> x, y = task[0] + >>> print(x.shape, y) + """ + + def __init__(self, samples: List[Dict], + subject_filter: Optional[List[str]] = None) -> None: + """Initializes the stress detection task dataset. + + Args: + samples (List[Dict]): List of dicts with keys 'eda' (np.ndarray), + 'label' (int), and 'subject_id' (str), as produced by + WESADDataset._load_and_window. + subject_filter (Optional[List[str]]): If provided, only include + samples from these subject IDs. Useful for LNSO cross-validation + splits. + """ + if subject_filter is not None: + samples = [s for s in samples if s["subject_id"] in subject_filter] + + self.samples = samples + self.subject_ids = sorted({s["subject_id"] for s in samples}) + labels = {s["label"] for s in samples} + self.num_classes = len(labels) + + logger.info( + f"StressDetectionDataset: {len(self.samples)} windows, " + f"{self.num_classes} classes, " + f"{len(self.subject_ids)} subjects." + ) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int): + """Returns a single EDA window and its label as tensors. + + Args: + idx (int): Sample index. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - eda: Float tensor of shape (window_size,) + - label: Long tensor scalar + """ + sample = self.samples[idx] + eda = torch.tensor(sample["eda"], dtype=torch.float32) + label = torch.tensor(sample["label"], dtype=torch.long) + return eda, label + + def get_subject_splits(self, test_subjects: List[str]): + """Returns train and test subsets by subject for LNSO cross-validation. + + Args: + test_subjects (List[str]): Subject IDs to hold out as test set. 
+ + Returns: + Tuple[StressDetectionDataset, StressDetectionDataset]: + Train dataset (all subjects not in test_subjects) and + test dataset (only test_subjects). + + Example:: + >>> train_ds, test_ds = task.get_subject_splits(["S2", "S3"]) + """ + train_subjects = [s for s in self.subject_ids if s not in test_subjects] + train_ds = StressDetectionDataset(self.samples, subject_filter=train_subjects) + test_ds = StressDetectionDataset(self.samples, subject_filter=test_subjects) + return train_ds, test_ds \ No newline at end of file diff --git a/tests/test_stress_detection.py b/tests/test_stress_detection.py new file mode 100644 index 000000000..cd88686f6 --- /dev/null +++ b/tests/test_stress_detection.py @@ -0,0 +1,81 @@ +"""Tests for the stress detection task using synthetic data.""" + +import os +import pickle +import tempfile +import unittest + +import numpy as np +import torch + +from pyhealth.datasets.wesad import WESADDataset, EDA_SAMPLE_RATE +from pyhealth.tasks.stress_detection import StressDetectionDataset + + +def _make_synthetic_subject(subject_id: str, root: str, n_seconds: int = 300) -> None: + n_eda = n_seconds * EDA_SAMPLE_RATE + n_chest = n_eda * 175 + eda = np.random.rand(n_eda, 1).astype(np.float32) + labels = np.ones(n_chest, dtype=int) + labels[n_chest // 3: 2 * n_chest // 3] = 2 + labels[2 * n_chest // 3:] = 3 + data = {"signal": {"wrist": {"EDA": eda}}, "label": labels} + subject_dir = os.path.join(root, subject_id) + os.makedirs(subject_dir, exist_ok=True) + with open(os.path.join(subject_dir, f"{subject_id}.pkl"), "wb") as f: + pickle.dump(data, f) + + +class TestStressDetectionDataset(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.subjects = ["S2", "S3", "S4"] + for sid in self.subjects: + _make_synthetic_subject(sid, self.tmp_dir) + raw = WESADDataset( + root=self.tmp_dir, + subjects=self.subjects, + window_size=60, + step_size=10, + label_map={1: 0, 2: 1}, + ) + self.task = 
StressDetectionDataset(raw.samples)
+
+    def test_len(self):
+        self.assertGreater(len(self.task), 0)
+
+    def test_getitem_types(self):
+        eda, label = self.task[0]
+        self.assertIsInstance(eda, torch.Tensor)
+        self.assertIsInstance(label, torch.Tensor)
+
+    def test_eda_shape(self):
+        eda, _ = self.task[0]
+        self.assertEqual(eda.shape, (60,))
+
+    def test_label_dtype(self):
+        _, label = self.task[0]
+        self.assertEqual(label.dtype, torch.long)
+
+    def test_subject_filter(self):
+        filtered = StressDetectionDataset(
+            self.task.samples, subject_filter=["S2"]
+        )
+        subjects_in_filtered = {s["subject_id"] for s in filtered.samples}
+        self.assertEqual(subjects_in_filtered, {"S2"})
+
+    def test_lnso_split(self):
+        train_ds, test_ds = self.task.get_subject_splits(test_subjects=["S2"])
+        train_subjects = {s["subject_id"] for s in train_ds.samples}
+        test_subjects = {s["subject_id"] for s in test_ds.samples}
+        self.assertNotIn("S2", train_subjects)
+        self.assertEqual(test_subjects, {"S2"})
+        self.assertEqual(len(train_ds) + len(test_ds), len(self.task))
+
+    def test_num_classes_binary(self):
+        self.assertEqual(self.task.num_classes, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file

From 669b243d6c48f4af6b7160e8e632700ad10b3af5 Mon Sep 17 00:00:00 2001
From: msaunders804
Date: Sun, 12 Apr 2026 13:51:43 -0500
Subject: [PATCH 3/6] add ContrastiveEDA Model, augmentations, NCE loss, and tests

---
 pyhealth/models/contrastive_eda.py | 503 +++++++++++++++++++++++++++++
 tests/test_contrastive_eda.py      | 184 +++++++++++
 2 files changed, 687 insertions(+)
 create mode 100644 pyhealth/models/contrastive_eda.py
 create mode 100644 tests/test_contrastive_eda.py

diff --git a/pyhealth/models/contrastive_eda.py b/pyhealth/models/contrastive_eda.py
new file mode 100644
index 000000000..d59094fd9
--- /dev/null
+++ b/pyhealth/models/contrastive_eda.py
@@ -0,0 +1,503 @@
+"""
+Contrastive EDA Encoder model for PyHealth.
+ +Implements the SimCLR-style contrastive pre-training framework from: + Matton, K., Lewis, R., Guttag, J., & Picard, R. (2023). + "Contrastive Learning of Electrodermal Activity Representations + for Stress Detection." CHIL 2023. + +Authors: + Megan Saunders, Jennifer Miranda, Jesus Torres + {meganas4, jm123, jesusst2}@illinois.edu +""" + +import copy +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import scipy.signal +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.interpolate import CubicSpline + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Augmentations +# --------------------------------------------------------------------------- + +class GaussianNoise: + """Adds Gaussian noise scaled to signal power.""" + def __init__(self, sigma_scale: float = 0.1): + self.sigma_scale = sigma_scale + + def __call__(self, x: np.ndarray) -> np.ndarray: + sigma = np.mean(np.abs(x - np.mean(x))) * self.sigma_scale + return x + np.random.normal(scale=sigma, size=len(x)) + + +class TemporalCutout: + """Zeros out a random contiguous segment of the signal.""" + def __init__(self, cutout_size: int = 10): + self.cutout_size = cutout_size + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = copy.deepcopy(x) + start = np.random.randint(0, max(1, len(x) - self.cutout_size)) + x[start:start + self.cutout_size] = 0.0 + return x + + +class ExtractPhasic: + """Extracts the phasic (high-frequency) component of EDA.""" + def __init__(self, data_freq: int = 4, cutoff_hz: float = 0.05): + self.data_freq = data_freq + self.b, self.a = scipy.signal.butter( + 4, [cutoff_hz], btype="highpass", output="ba", fs=data_freq + ) + + def __call__(self, x: np.ndarray) -> np.ndarray: + return scipy.signal.filtfilt(self.b, self.a, x) + + +class ExtractTonic: + """Extracts the tonic (low-frequency) component of EDA.""" + def __init__(self, data_freq: int 
= 4, cutoff_hz: float = 0.05): + self.data_freq = data_freq + self.b, self.a = scipy.signal.butter( + 4, [cutoff_hz], btype="lowpass", output="ba", fs=data_freq + ) + + def __call__(self, x: np.ndarray) -> np.ndarray: + return scipy.signal.filtfilt(self.b, self.a, x) + + +class LooseSensorArtifact: + """Simulates a loose sensor dropout artifact.""" + def __init__(self, width: int = 4, smooth_width: int = 2): + self.width = width + self.smooth_width = smooth_width + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = copy.deepcopy(x) + artifact_width = self.width + if len(x) <= artifact_width: + return x + artifact_start = np.random.randint(0, len(x) - artifact_width + 1) + artifact_end = artifact_start + artifact_width - 1 + drop_start = artifact_start + self.smooth_width + drop_end = artifact_end - self.smooth_width + if drop_start >= drop_end: + return x + mean_amp = np.mean(x[drop_start:drop_end + 1]) + x[drop_start:drop_end + 1] -= mean_amp + x[x < 0] = 0.0 + return x + + +class AmplitudeScaling: + """Scales signal amplitude by a random constant factor.""" + def __init__(self, scale_min: float = 0.5, scale_max: float = 1.5): + self.scale_min = scale_min + self.scale_max = scale_max + + def __call__(self, x: np.ndarray) -> np.ndarray: + scale = np.random.uniform(self.scale_min, self.scale_max) + return x * scale + + +# Augmentation registry: name -> class +AUGMENTATION_REGISTRY = { + "gaussian_noise": GaussianNoise, + "temporal_cutout": TemporalCutout, + "extract_phasic": ExtractPhasic, + "extract_tonic": ExtractTonic, + "loose_sensor_artifact": LooseSensorArtifact, + "amplitude_scaling": AmplitudeScaling, +} + +# Preset augmentation groups for ablation +AUGMENTATION_GROUPS = { + "full": [ + "gaussian_noise", "temporal_cutout", "amplitude_scaling", + "extract_phasic", "extract_tonic", "loose_sensor_artifact", + ], + "generic_only": [ + "gaussian_noise", "temporal_cutout", "amplitude_scaling", + ], + "eda_specific_only": [ + "extract_phasic", 
"extract_tonic", "loose_sensor_artifact", + ], +} + + +def apply_augmentation_pair( + x: np.ndarray, + augmentation_names: List[str], +) -> Tuple[np.ndarray, np.ndarray]: + """Applies two independently sampled augmentations to produce a positive pair. + + Args: + x: Raw EDA window as numpy array of shape (window_size,). + augmentation_names: List of augmentation names to sample from. + + Returns: + Tuple of two augmented views, each of shape (window_size,). + """ + aug_classes = [AUGMENTATION_REGISTRY[n] for n in augmentation_names] + + def _apply_one(signal): + aug_fn = np.random.choice(aug_classes)() + out = aug_fn(signal) + # ensure output length matches input (some filters may shift length) + if len(out) != len(signal): + out = out[:len(signal)] + return out.astype(np.float32) + + return _apply_one(x.copy()), _apply_one(x.copy()) + + +# --------------------------------------------------------------------------- +# NT-Xent / NCE Loss +# --------------------------------------------------------------------------- + +class NCELoss(nn.Module): + """Noise Contrastive Estimation loss for contrastive pre-training. + + Ported directly from the authors' loss/nt_xent.py implementation. + + Args: + temperature: Softmax temperature. Lower = sharper distribution. + """ + + def __init__(self, temperature: float = 0.1): + super().__init__() + self.temperature = temperature + + def forward( + self, + embeddings_v1: torch.Tensor, + embeddings_v2: torch.Tensor, + ) -> torch.Tensor: + """Computes symmetric NCE loss between two sets of embeddings. + + Args: + embeddings_v1: View 1 embeddings, shape (N, D). + embeddings_v2: View 2 embeddings, shape (N, D). + + Returns: + Scalar loss tensor. 
+ """ + norm1 = embeddings_v1.norm(dim=1).unsqueeze(0) + norm2 = embeddings_v2.norm(dim=1).unsqueeze(0) + sim_matrix = torch.mm(embeddings_v1, embeddings_v2.t()) + norm_matrix = torch.mm(norm1.t(), norm2) + sim_matrix = sim_matrix / (norm_matrix * self.temperature) + sim_matrix_exp = torch.exp(sim_matrix) + + # positive pairs are on the diagonal + pos_mask = torch.eye( + len(embeddings_v1), dtype=torch.bool, device=embeddings_v1.device + ) + + row_sum = sim_matrix_exp.sum(dim=1) + sim_row = sim_matrix_exp / row_sum.unsqueeze(1) + view1_loss = -torch.mean(torch.log(sim_row[pos_mask])) + + col_sum = sim_matrix_exp.sum(dim=0) + sim_col = sim_matrix_exp / col_sum.unsqueeze(0) + view2_loss = -torch.mean(torch.log(sim_col[pos_mask])) + + return (view1_loss + view2_loss) / 2.0 + + +# --------------------------------------------------------------------------- +# 1D CNN Encoder +# --------------------------------------------------------------------------- + +class EDAEncoder(nn.Module): + """Lightweight 1D CNN encoder for EDA windows. + + Architecture follows the authors' implementation: three convolutional + blocks with batch normalization and ReLU, followed by global average + pooling and a linear projection head for contrastive training. + + Args: + window_size: Length of the input EDA window in samples. + embed_dim: Dimension of the output embedding. + proj_dim: Dimension of the contrastive projection head output. 
+ """ + + def __init__( + self, + window_size: int = 60, + embed_dim: int = 128, + proj_dim: int = 64, + ): + super().__init__() + self.window_size = window_size + self.embed_dim = embed_dim + self.proj_dim = proj_dim + + self.encoder = nn.Sequential( + # Block 1 + nn.Conv1d(1, 32, kernel_size=7, padding=3), + nn.BatchNorm1d(32), + nn.ReLU(), + nn.MaxPool1d(2), + # Block 2 + nn.Conv1d(32, 64, kernel_size=5, padding=2), + nn.BatchNorm1d(64), + nn.ReLU(), + nn.MaxPool1d(2), + # Block 3 + nn.Conv1d(64, embed_dim, kernel_size=3, padding=1), + nn.BatchNorm1d(embed_dim), + nn.ReLU(), + ) + + # Global average pooling -> (N, embed_dim) + self.pool = nn.AdaptiveAvgPool1d(1) + + # Projection head for contrastive loss + self.projector = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.ReLU(), + nn.Linear(embed_dim, proj_dim), + ) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Returns encoder embeddings without projection head. + + Args: + x: Input tensor of shape (N, window_size). + + Returns: + Embeddings of shape (N, embed_dim). + """ + x = x.unsqueeze(1) # (N, 1, window_size) + x = self.encoder(x) # (N, embed_dim, T') + x = self.pool(x).squeeze(2) # (N, embed_dim) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Returns projected embeddings for contrastive loss. + + Args: + x: Input tensor of shape (N, window_size). + + Returns: + Projected embeddings of shape (N, proj_dim). + """ + h = self.encode(x) + return self.projector(h) + + +# --------------------------------------------------------------------------- +# Full ContrastiveEDAModel +# --------------------------------------------------------------------------- + +class ContrastiveEDAModel(nn.Module): + """Contrastive EDA encoder model for PyHealth. + + Supports two modes: + - **pretrain**: SimCLR-style contrastive pre-training with NT-Xent loss. + Input EDA windows are augmented into positive pairs and the encoder + is trained to bring them together in embedding space. 
+ - **finetune**: Supervised stress detection. The encoder is loaded from + a pre-trained checkpoint, a linear classifier head is appended, and + the full network (or encoder-frozen variant) is fine-tuned with + cross-entropy loss. + + Args: + window_size: EDA window length in samples. + embed_dim: Encoder output dimension. + proj_dim: Contrastive projection head dimension. + num_classes: Number of output classes for finetune mode. + augmentation_group: One of 'full', 'generic_only', 'eda_specific_only', + or a custom list of augmentation names. Controls which augmentations + are applied during pre-training. Defaults to 'full'. + temperature: NT-Xent loss temperature. + freeze_encoder: If True in finetune mode, encoder weights are frozen + and only the classifier head is trained. + + Example:: + >>> model = ContrastiveEDAModel(window_size=60, num_classes=2) + >>> # Pretrain + >>> loss = model.pretrain_step(batch_eda) + >>> # Finetune + >>> model.set_finetune_mode(num_classes=2) + >>> logits = model(batch_eda) + """ + + def __init__( + self, + window_size: int = 60, + embed_dim: int = 128, + proj_dim: int = 64, + num_classes: int = 2, + augmentation_group: str = "full", + temperature: float = 0.1, + freeze_encoder: bool = False, + ): + super().__init__() + self.window_size = window_size + self.num_classes = num_classes + self.freeze_encoder = freeze_encoder + self._mode = "pretrain" + + # Resolve augmentation list + if isinstance(augmentation_group, list): + self.augmentation_names = augmentation_group + else: + if augmentation_group not in AUGMENTATION_GROUPS: + raise ValueError( + f"augmentation_group must be one of " + f"{list(AUGMENTATION_GROUPS.keys())} or a list of names. 
" + f"Got: {augmentation_group}" + ) + self.augmentation_names = AUGMENTATION_GROUPS[augmentation_group] + + self.encoder = EDAEncoder( + window_size=window_size, + embed_dim=embed_dim, + proj_dim=proj_dim, + ) + self.loss_fn = NCELoss(temperature=temperature) + + # Classifier head (added in finetune mode) + self.classifier: Optional[nn.Linear] = None + + # ------------------------------------------------------------------ + # Mode switching + # ------------------------------------------------------------------ + + def set_finetune_mode(self, num_classes: Optional[int] = None) -> None: + """Switches model to finetune mode and attaches classifier head. + + Args: + num_classes: Number of output classes. Uses self.num_classes + if not provided. + """ + self._mode = "finetune" + if num_classes is not None: + self.num_classes = num_classes + self.classifier = nn.Linear(self.encoder.embed_dim, self.num_classes) + if self.freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + logger.info("Encoder frozen. Training classifier head only.") + else: + logger.info("Fine-tuning full network.") + + def set_pretrain_mode(self) -> None: + """Switches model back to contrastive pre-training mode.""" + self._mode = "pretrain" + for param in self.encoder.parameters(): + param.requires_grad = True + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + In pretrain mode: returns projected embeddings of shape (N, proj_dim). + In finetune mode: returns class logits of shape (N, num_classes). + + Args: + x: EDA windows of shape (N, window_size). + + Returns: + Projected embeddings or logits depending on current mode. 
+ """ + if self._mode == "pretrain": + return self.encoder(x) + else: + if self.classifier is None: + raise RuntimeError( + "Call set_finetune_mode() before running in finetune mode." + ) + h = self.encoder.encode(x) + return self.classifier(h) + + # ------------------------------------------------------------------ + # Training steps + # ------------------------------------------------------------------ + + def pretrain_step( + self, x: torch.Tensor + ) -> torch.Tensor: + """Computes contrastive loss for a batch of EDA windows. + + Augments each window into two views and computes NT-Xent loss + between the projected embeddings. + + Args: + x: EDA windows of shape (N, window_size). + + Returns: + Scalar contrastive loss tensor. + """ + device = x.device + x_np = x.cpu().numpy() + + views1, views2 = [], [] + for window in x_np: + v1, v2 = apply_augmentation_pair(window, self.augmentation_names) + views1.append(v1) + views2.append(v2) + + v1_tensor = torch.tensor(np.stack(views1), dtype=torch.float32).to(device) + v2_tensor = torch.tensor(np.stack(views2), dtype=torch.float32).to(device) + + z1 = self.encoder(v1_tensor) + z2 = self.encoder(v2_tensor) + + return self.loss_fn(z1, z2) + + def finetune_step( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes cross-entropy loss and logits for supervised fine-tuning. + + Args: + x: EDA windows of shape (N, window_size). + y: Integer class labels of shape (N,). + + Returns: + Tuple of (loss scalar, logits of shape (N, num_classes)). + """ + logits = self.forward(x) + loss = F.cross_entropy(logits, y) + return loss, logits + + # ------------------------------------------------------------------ + # Checkpoint helpers + # ------------------------------------------------------------------ + + def save_encoder(self, path: str) -> None: + """Saves encoder weights to disk. + + Args: + path: File path for the checkpoint (.pt file). 
+ """ + torch.save(self.encoder.state_dict(), path) + logger.info(f"Encoder saved to {path}") + + def load_encoder(self, path: str, strict: bool = True) -> None: + """Loads encoder weights from a checkpoint. + + Args: + path: File path to the encoder checkpoint (.pt file). + strict: Whether to strictly enforce key matching. + """ + state = torch.load(path, map_location="cpu") + self.encoder.load_state_dict(state, strict=strict) + logger.info(f"Encoder loaded from {path}") \ No newline at end of file diff --git a/tests/test_contrastive_eda.py b/tests/test_contrastive_eda.py new file mode 100644 index 000000000..c7c5f26b4 --- /dev/null +++ b/tests/test_contrastive_eda.py @@ -0,0 +1,184 @@ +"""Tests for ContrastiveEDAModel using synthetic data.""" + +import os +import tempfile +import unittest + +import numpy as np +import torch + +from pyhealth.models.contrastive_eda import ( + ContrastiveEDAModel, + NCELoss, + EDAEncoder, + apply_augmentation_pair, + AUGMENTATION_GROUPS, + AUGMENTATION_REGISTRY, +) + + +class TestEDAEncoder(unittest.TestCase): + + def setUp(self): + self.window_size = 60 + self.batch_size = 8 + self.encoder = EDAEncoder(window_size=self.window_size) + self.x = torch.randn(self.batch_size, self.window_size) + + def test_encode_output_shape(self): + h = self.encoder.encode(self.x) + self.assertEqual(h.shape, (self.batch_size, self.encoder.embed_dim)) + + def test_forward_output_shape(self): + z = self.encoder(self.x) + self.assertEqual(z.shape, (self.batch_size, self.encoder.proj_dim)) + + def test_no_nan_in_output(self): + z = self.encoder(self.x) + self.assertFalse(torch.isnan(z).any()) + + +class TestNCELoss(unittest.TestCase): + + def test_loss_is_scalar(self): + loss_fn = NCELoss(temperature=0.1) + z1 = torch.randn(8, 64) + z2 = torch.randn(8, 64) + loss = loss_fn(z1, z2) + self.assertEqual(loss.shape, torch.Size([])) + + def test_loss_is_positive(self): + loss_fn = NCELoss(temperature=0.1) + z1 = torch.randn(8, 64) + z2 = torch.randn(8, 64) + 
loss = loss_fn(z1, z2) + self.assertGreater(loss.item(), 0) + + def test_identical_views_lower_loss(self): + loss_fn = NCELoss(temperature=0.1) + z = torch.randn(8, 64) + loss_same = loss_fn(z, z) + loss_diff = loss_fn(z, torch.randn(8, 64)) + self.assertLess(loss_same.item(), loss_diff.item()) + + +class TestAugmentations(unittest.TestCase): + + def setUp(self): + self.x = np.random.rand(60).astype(np.float32) + + def test_all_augmentations_preserve_shape(self): + for name, cls in AUGMENTATION_REGISTRY.items(): + with self.subTest(augmentation=name): + aug = cls() + out = aug(self.x.copy()) + self.assertEqual( + len(out), len(self.x), + f"{name} changed output length" + ) + + def test_augmentation_pair_shapes(self): + v1, v2 = apply_augmentation_pair(self.x, AUGMENTATION_GROUPS["full"]) + self.assertEqual(v1.shape, self.x.shape) + self.assertEqual(v2.shape, self.x.shape) + + def test_augmentation_pair_differs(self): + v1, v2 = apply_augmentation_pair(self.x, AUGMENTATION_GROUPS["full"]) + self.assertFalse(np.allclose(v1, v2)) + + +class TestContrastiveEDAModel(unittest.TestCase): + + def setUp(self): + self.window_size = 60 + self.batch_size = 8 + self.x = torch.randn(self.batch_size, self.window_size) + self.y = torch.randint(0, 2, (self.batch_size,)) + + def test_pretrain_step_returns_scalar(self): + model = ContrastiveEDAModel(window_size=self.window_size) + loss = model.pretrain_step(self.x) + self.assertEqual(loss.shape, torch.Size([])) + + def test_pretrain_step_loss_positive(self): + model = ContrastiveEDAModel(window_size=self.window_size) + loss = model.pretrain_step(self.x) + self.assertGreater(loss.item(), 0) + + def test_finetune_step_output_shapes(self): + model = ContrastiveEDAModel(window_size=self.window_size, num_classes=2) + model.set_finetune_mode() + loss, logits = model.finetune_step(self.x, self.y) + self.assertEqual(logits.shape, (self.batch_size, 2)) + self.assertEqual(loss.shape, torch.Size([])) + + def test_forward_pretrain_mode(self): + 
model = ContrastiveEDAModel(window_size=self.window_size) + z = model(self.x) + self.assertEqual(z.shape[0], self.batch_size) + + def test_forward_finetune_mode(self): + model = ContrastiveEDAModel(window_size=self.window_size, num_classes=2) + model.set_finetune_mode() + logits = model(self.x) + self.assertEqual(logits.shape, (self.batch_size, 2)) + + def test_forward_finetune_without_set_raises(self): + model = ContrastiveEDAModel(window_size=self.window_size) + model._mode = "finetune" + with self.assertRaises(RuntimeError): + model(self.x) + + def test_freeze_encoder(self): + model = ContrastiveEDAModel( + window_size=self.window_size, + freeze_encoder=True, + ) + model.set_finetune_mode() + for param in model.encoder.parameters(): + self.assertFalse(param.requires_grad) + self.assertTrue(model.classifier.weight.requires_grad) + + def test_augmentation_groups(self): + for group in ["full", "generic_only", "eda_specific_only"]: + with self.subTest(group=group): + model = ContrastiveEDAModel( + window_size=self.window_size, + augmentation_group=group, + ) + loss = model.pretrain_step(self.x) + self.assertGreater(loss.item(), 0) + + def test_invalid_augmentation_group_raises(self): + with self.assertRaises(ValueError): + ContrastiveEDAModel(augmentation_group="nonexistent") + + def test_save_load_encoder(self): + model = ContrastiveEDAModel(window_size=self.window_size) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "encoder.pt") + model.save_encoder(path) + model2 = ContrastiveEDAModel(window_size=self.window_size) + model2.load_encoder(path) + z1 = model.encoder.encode(self.x) + z2 = model2.encoder.encode(self.x) + self.assertTrue(torch.allclose(z1, z2)) + + def test_three_class_finetune(self): + model = ContrastiveEDAModel(window_size=self.window_size, num_classes=3) + model.set_finetune_mode() + y = torch.randint(0, 3, (self.batch_size,)) + loss, logits = model.finetune_step(self.x, y) + self.assertEqual(logits.shape, (self.batch_size, 
3)) + + def test_mode_switching(self): + model = ContrastiveEDAModel(window_size=self.window_size, num_classes=2) + self.assertEqual(model._mode, "pretrain") + model.set_finetune_mode() + self.assertEqual(model._mode, "finetune") + model.set_pretrain_mode() + self.assertEqual(model._mode, "pretrain") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 282b32fa59eb563fe1a4cc99e85ea52ce0d78776 Mon Sep 17 00:00:00 2001 From: msaunders804 Date: Sun, 12 Apr 2026 13:53:57 -0500 Subject: [PATCH 4/6] register WESADDataset and ContrastiveEDAModel in package __init__ --- pyhealth/datasets/__init__.py | 1 + pyhealth/models/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..6f6393349 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .wesad import WESADDataset diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..7c405d3b6 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,3 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .contrastive_eda import ContrastiveEDAModel From 21ba378988a6932a69e64afcd4500b2a95274fb3 Mon Sep 17 00:00:00 2001 From: msaunders804 Date: Sun, 12 Apr 2026 13:56:10 -0500 Subject: [PATCH 5/6] add full pipeline example script with ablation study --- .../wesad_stress_detection_contrastive_eda.py | 380 ++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 examples/wesad_stress_detection_contrastive_eda.py diff --git a/examples/wesad_stress_detection_contrastive_eda.py b/examples/wesad_stress_detection_contrastive_eda.py new file mode 100644 index 000000000..9ba9457f0 --- /dev/null +++ 
b/examples/wesad_stress_detection_contrastive_eda.py @@ -0,0 +1,380 @@ +""" +Full pipeline example: Contrastive EDA pre-training and stress detection on WESAD. + +Reproduces the core experiment from: + Matton, K., Lewis, R., Guttag, J., & Picard, R. (2023). + "Contrastive Learning of Electrodermal Activity Representations + for Stress Detection." CHIL 2023. + +This script demonstrates: + 1. Loading and windowing the WESAD dataset + 2. Contrastive pre-training of the EDA encoder + 3. Fine-tuning for binary stress detection + 4. Ablation: full vs. generic-only vs. EDA-specific augmentations + +Usage: + python examples/wesad_stress_detection_contrastive_eda.py \ + --data_root /path/to/WESAD \ + --output_dir ./outputs \ + --augmentation_group full \ + --pretrain_epochs 50 \ + --finetune_epochs 20 \ + --label_fraction 0.01 + +Authors: + Megan Saunders, Jennifer Miranda, Jesus Torres + {meganas4, jm123, jesusst2}@illinois.edu +""" + +import argparse +import logging +import os +from typing import Dict, List + +import numpy as np +import torch +import torch.optim as optim +from sklearn.metrics import balanced_accuracy_score +from torch.utils.data import DataLoader + +from pyhealth.datasets import WESADDataset +from pyhealth.models import ContrastiveEDAModel +from pyhealth.tasks.stress_detection import StressDetectionDataset + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +# LNSO folds matching the authors' dataset_splits/WESAD/ files +LNSO_FOLDS = [ + ["S2", "S3"], + ["S4", "S5"], + ["S6", "S7"], + ["S8", "S9"], + ["S10", "S11"], +] + + +def pretrain( + model: ContrastiveEDAModel, + train_loader: DataLoader, + epochs: int, + device: torch.device, + lr: float = 1e-3, +) -> List[float]: + """Runs contrastive pre-training loop. + + Args: + model: ContrastiveEDAModel in pretrain mode. + train_loader: DataLoader yielding (eda, label) tuples. + epochs: Number of training epochs. 
+ device: Torch device. + lr: Learning rate. + + Returns: + List of per-epoch training losses. + """ + model.set_pretrain_mode() + model.to(device) + optimizer = optim.Adam(model.parameters(), lr=lr) + losses = [] + + for epoch in range(epochs): + model.train() + epoch_loss = 0.0 + for x, _ in train_loader: + x = x.to(device) + optimizer.zero_grad() + loss = model.pretrain_step(x) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + avg = epoch_loss / len(train_loader) + losses.append(avg) + if (epoch + 1) % 10 == 0: + logger.info(f" Pretrain epoch {epoch+1}/{epochs} loss={avg:.4f}") + + return losses + + +def finetune_and_evaluate( + model: ContrastiveEDAModel, + train_ds: StressDetectionDataset, + test_ds: StressDetectionDataset, + label_fraction: float, + epochs: int, + device: torch.device, + lr: float = 1e-3, + batch_size: int = 64, + freeze_encoder: bool = False, +) -> float: + """Fine-tunes model on a fraction of labeled data and evaluates. + + Args: + model: ContrastiveEDAModel with pre-trained encoder. + train_ds: Training StressDetectionDataset. + test_ds: Test StressDetectionDataset. + label_fraction: Fraction of training labels to use (e.g. 0.01 = 1%). + epochs: Number of fine-tuning epochs. + device: Torch device. + lr: Learning rate. + batch_size: Batch size. + freeze_encoder: Whether to freeze encoder during fine-tuning. + + Returns: + Balanced accuracy on the test set. 
+ """ + # Subsample labeled training data + n_labeled = max(1, int(len(train_ds) * label_fraction)) + indices = np.random.choice(len(train_ds), size=n_labeled, replace=False) + labeled_samples = [train_ds.samples[i] for i in indices] + labeled_ds = StressDetectionDataset(labeled_samples) + + train_loader = DataLoader(labeled_ds, batch_size=batch_size, shuffle=True) + test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False) + + model.freeze_encoder = freeze_encoder + model.set_finetune_mode(num_classes=2) + model.to(device) + + optimizer = optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=lr + ) + + for epoch in range(epochs): + model.train() + for x, y in train_loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + loss, _ = model.finetune_step(x, y) + loss.backward() + optimizer.step() + + # Evaluate + model.eval() + all_preds, all_labels = [], [] + with torch.no_grad(): + for x, y in test_loader: + x = x.to(device) + logits = model(x) + preds = logits.argmax(dim=1).cpu().numpy() + all_preds.extend(preds) + all_labels.extend(y.numpy()) + + return balanced_accuracy_score(all_labels, all_preds) + + +def run_fold( + fold_idx: int, + test_subjects: List[str], + all_samples: List[Dict], + augmentation_group: str, + pretrain_epochs: int, + finetune_epochs: int, + label_fraction: float, + device: torch.device, + output_dir: str, + window_size: int = 60, +) -> float: + """Runs one LNSO fold: pretrain on train subjects, evaluate on test subjects. + + Args: + fold_idx: Fold number for logging. + test_subjects: Subject IDs held out for testing. + all_samples: All windowed samples from WESADDataset. + augmentation_group: Augmentation group name for ContrastiveEDAModel. + pretrain_epochs: Number of contrastive pre-training epochs. + finetune_epochs: Number of supervised fine-tuning epochs. + label_fraction: Fraction of labeled training data to use. + device: Torch device. + output_dir: Directory to save encoder checkpoints. 
+ window_size: EDA window size in samples. + + Returns: + Balanced accuracy for this fold. + """ + logger.info(f"\nFold {fold_idx} | test subjects: {test_subjects}") + + full_task = StressDetectionDataset(all_samples) + train_ds, test_ds = full_task.get_subject_splits(test_subjects) + + logger.info(f" Train windows: {len(train_ds)} | Test windows: {len(test_ds)}") + + # Pretrain + model = ContrastiveEDAModel( + window_size=window_size, + num_classes=2, + augmentation_group=augmentation_group, + ) + train_loader = DataLoader(train_ds, batch_size=64, shuffle=True) + pretrain(model, train_loader, epochs=pretrain_epochs, device=device) + + # Save encoder checkpoint + os.makedirs(output_dir, exist_ok=True) + ckpt_path = os.path.join(output_dir, f"encoder_fold{fold_idx}_{augmentation_group}.pt") + model.save_encoder(ckpt_path) + + # Finetune and evaluate + bal_acc = finetune_and_evaluate( + model=model, + train_ds=train_ds, + test_ds=test_ds, + label_fraction=label_fraction, + epochs=finetune_epochs, + device=device, + ) + + logger.info(f" Fold {fold_idx} balanced accuracy: {bal_acc:.4f}") + return bal_acc + + +def run_ablation( + all_samples: List[Dict], + pretrain_epochs: int, + finetune_epochs: int, + label_fraction: float, + device: torch.device, + output_dir: str, + window_size: int = 60, +) -> None: + """Runs ablation study comparing augmentation groups. + + Evaluates three conditions across all LNSO folds: + - full: all augmentations (EDA-specific + generic) + - generic_only: Gaussian noise, temporal cutout, amplitude scaling + - eda_specific_only: tonic/phasic extraction, loose sensor artifact + + Args: + all_samples: All windowed samples from WESADDataset. + pretrain_epochs: Contrastive pre-training epochs. + finetune_epochs: Fine-tuning epochs. + label_fraction: Fraction of labeled training data. + device: Torch device. + output_dir: Output directory for checkpoints. + window_size: EDA window size in samples. 
+ """ + groups = ["full", "generic_only", "eda_specific_only"] + results = {g: [] for g in groups} + + for group in groups: + logger.info(f"\n{'='*60}") + logger.info(f"Augmentation group: {group}") + logger.info(f"{'='*60}") + for fold_idx, test_subjects in enumerate(LNSO_FOLDS): + bal_acc = run_fold( + fold_idx=fold_idx, + test_subjects=test_subjects, + all_samples=all_samples, + augmentation_group=group, + pretrain_epochs=pretrain_epochs, + finetune_epochs=finetune_epochs, + label_fraction=label_fraction, + device=device, + output_dir=output_dir, + window_size=window_size, + ) + results[group].append(bal_acc) + + # Print results table + logger.info("\n" + "="*60) + logger.info("ABLATION RESULTS") + logger.info("="*60) + logger.info(f"{'Augmentation Group':<25} {'Mean Bal Acc':>12} {'Std':>8}") + logger.info("-"*60) + for group in groups: + scores = results[group] + logger.info( + f"{group:<25} {np.mean(scores):>12.4f} {np.std(scores):>8.4f}" + ) + logger.info("="*60) + + +def main(): + parser = argparse.ArgumentParser( + description="Contrastive EDA pre-training and stress detection on WESAD" + ) + parser.add_argument( + "--data_root", type=str, required=True, + help="Path to WESAD dataset root directory" + ) + parser.add_argument( + "--output_dir", type=str, default="./outputs", + help="Directory to save encoder checkpoints" + ) + parser.add_argument( + "--augmentation_group", type=str, default="full", + choices=["full", "generic_only", "eda_specific_only", "ablation"], + help="Augmentation group to use. Use 'ablation' to run full ablation study." 
+ ) + parser.add_argument( + "--pretrain_epochs", type=int, default=50, + help="Number of contrastive pre-training epochs" + ) + parser.add_argument( + "--finetune_epochs", type=int, default=20, + help="Number of supervised fine-tuning epochs" + ) + parser.add_argument( + "--label_fraction", type=float, default=0.01, + help="Fraction of labeled training data to use (default: 0.01 = 1%%)" + ) + parser.add_argument( + "--window_size", type=int, default=60, + help="EDA window size in samples (default: 60 = 15 seconds at 4Hz)" + ) + parser.add_argument( + "--seed", type=int, default=42, + help="Random seed for reproducibility" + ) + args = parser.parse_args() + + # Reproducibility + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Load dataset + logger.info(f"Loading WESAD from {args.data_root}") + dataset = WESADDataset( + root=args.data_root, + window_size=args.window_size, + step_size=10, + label_map={1: 0, 2: 1}, + ) + logger.info(f"Total windows: {len(dataset)}") + + if args.augmentation_group == "ablation": + run_ablation( + all_samples=dataset.samples, + pretrain_epochs=args.pretrain_epochs, + finetune_epochs=args.finetune_epochs, + label_fraction=args.label_fraction, + device=device, + output_dir=args.output_dir, + window_size=args.window_size, + ) + else: + # Single augmentation group across all folds + fold_scores = [] + for fold_idx, test_subjects in enumerate(LNSO_FOLDS): + bal_acc = run_fold( + fold_idx=fold_idx, + test_subjects=test_subjects, + all_samples=dataset.samples, + augmentation_group=args.augmentation_group, + pretrain_epochs=args.pretrain_epochs, + finetune_epochs=args.finetune_epochs, + label_fraction=args.label_fraction, + device=device, + output_dir=args.output_dir, + window_size=args.window_size, + ) + fold_scores.append(bal_acc) + + logger.info(f"\nMean balanced accuracy: {np.mean(fold_scores):.4f}") + 
logger.info(f"Std: {np.std(fold_scores):.4f}") + + +if __name__ == "__main__": + main() \ No newline at end of file From 44da4341da0295d724abcf55b252320d24ec728d Mon Sep 17 00:00:00 2001 From: msaunders804 Date: Sun, 12 Apr 2026 14:01:03 -0500 Subject: [PATCH 6/6] Add WESAD contrastive EDA pipeline: dataset, task, model, tests, and example - pyhealth/datasets/wesad.py: WESAD dataset class with EDA windowing - pyhealth/tasks/stress_detection.py: Stress detection task with LNSO splits - pyhealth/models/contrastive_eda.py: SimCLR contrastive encoder with NT-Xent loss and EDA augmentations - examples/wesad_stress_detection_contrastive_eda.py: Full pipeline with augmentation ablation - tests: 35 tests covering dataset, task, and model Reproduces Matton et al. CHIL 2023. --- pyhealth/tasks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..22c96404e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .stress_detection import StressDetectionDataset \ No newline at end of file