Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,4 @@ API Reference
models/pyhealth.models.TextEmbedding
models/pyhealth.models.BIOT
models/pyhealth.models.unified_multimodal_embedding_docs
models/pyhealth.models.wav2sleep
7 changes: 7 additions & 0 deletions docs/api/models/pyhealth.models.wav2sleep.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.models.wav2sleep
=========================

.. automodule:: pyhealth.models.wav2sleep
    :members:
    :undoc-members:
    :show-inheritance:
54 changes: 54 additions & 0 deletions examples/mimic4_sleep_staging_wav2sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Example script for Sleep Stage Classification using Wav2Sleep on MIMIC-IV dataset.
This script demonstrates the model's robustness through an Ablation Study
on missing modalities (Stochastic Masking), adapted for MIMIC-IV clinical signals.
"""

import torch
from pyhealth.models import Wav2Sleep

def run_example():
    """Demonstrate Wav2Sleep sleep staging on MIMIC-IV-style synthetic signals.

    Builds a small random batch of ECG + respiratory signals, then runs an
    ablation in which the respiratory channel is zeroed out to mimic sensor
    loss, showing that inference still succeeds on the remaining modality.
    """
    print("--- PyHealth Example: MIMIC-IV Sleep Staging with Wav2Sleep ---")

    # 1. Synthetic stand-in for MIMIC-IV clinical signals (ECG + Respiratory/PPG).
    #    2 patients, 5 sleep epochs each, 3000 samples per epoch.
    n_batch, n_epochs, n_samples = 2, 5, 3000
    full_batch = {
        "ecg": torch.randn(n_batch, n_epochs, n_samples),
        "resp": torch.randn(n_batch, n_epochs, n_samples),
        "label": torch.randint(0, 5, (n_batch, n_epochs)),
    }

    # 2. Build the model. mask_prob controls the per-modality stochastic
    #    masking used during training for robustness to missing channels.
    model = Wav2Sleep(
        dataset=None,
        feature_keys=["ecg", "resp"],
        label_key="label",
        mode="multiclass",
        embedding_dim=128,
        mask_prob={"ecg": 0.5, "resp": 0.5},
    )

    # 3. Ablation: simulate losing the respiratory channel by zeroing it.
    print("\n[Ablation] Scenario: Respiratory sensor noise/loss in MIMIC-IV")
    ablated_batch = dict(full_batch)
    ablated_batch["resp"] = torch.zeros_like(full_batch["resp"])

    model.eval()
    with torch.no_grad():
        output = model(**ablated_batch)

    print(f"Inference Successful!")
    print(f"Loss with missing modality: {output['loss']:.4f}")
    print(f"Output probability shape: {output['y_prob'].shape} (5 Sleep Stages)")

    print("\n[Clinical Value]: The model maintains diagnostic capability "
          "even with incomplete bedside monitor data.")

if __name__ == "__main__":
    run_example()
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
from .wav2sleep import Wav2Sleep
159 changes: 159 additions & 0 deletions pyhealth/models/wav2sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
from typing import Dict, List
from pyhealth.models import BaseModel


class ResBlock(nn.Module):
    """Residual 1-D convolutional block used by the signal encoders.

    Applies three same-padded Conv1d layers (GELU between them), adds a
    shortcut connection (1x1 conv when the channel count changes, identity
    otherwise), passes the sum through GELU, and finally halves the temporal
    resolution with a stride-2 max-pool.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResBlock, self).__init__()
        same_pad = kernel_size // 2  # keeps temporal length unchanged
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=same_pad),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=same_pad),
            nn.GELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=same_pad),
        )
        # Match channel counts on the skip path only when necessary.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels, out_channels, 1)
        self.pool = nn.MaxPool1d(2)
        self.gelu = nn.GELU()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.gelu(self.conv(x) + identity)
        return self.pool(out)


class Wav2Sleep(BaseModel):
    """Wav2Sleep: A Unified Multi-Modal Approach to Sleep Stage Classification.

    Paper: Carter, J. F.; and Tarassenko, L. 2024. wav2sleep: A Unified
    Multi-Modal Approach to Sleep Stage Classification from Physiological Signals.

    The model consists of modality-specific CNN encoders, a transformer-based
    epoch mixer with a [CLS] token, and a dilated CNN sequence mixer.

    Args:
        dataset: PyHealth dataset (may be None for standalone usage).
        feature_keys: names of the input signal modalities (e.g. ["ecg", "ppg"]).
        label_key: key of the per-epoch sleep-stage labels in the input batch.
        mode: prediction mode string (stored; e.g. "multiclass").
        embedding_dim: dimensionality of the shared epoch embedding space.
        nhead: number of attention heads in the epoch-mixer transformer.
        num_layers: number of transformer encoder layers in the epoch mixer.
        mask_prob: per-modality stochastic-masking probabilities used during
            training; defaults to the values from the paper.
    """

    # Channel count produced by the CNN encoders: 16 doubled three times.
    _ENCODER_OUT_CHANNELS = 128

    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        nhead: int = 8,
        num_layers: int = 2,
        mask_prob: Dict[str, float] = None,
        **kwargs,
    ):
        super(Wav2Sleep, self).__init__(
            dataset=dataset,
            **kwargs
        )

        self.feature_keys = feature_keys
        self.label_key = label_key
        self.mode = mode
        self.embedding_dim = embedding_dim

        # Five AASM sleep stages (W, N1, N2, N3, REM). The original code
        # branched on dataset.label_schema but assigned 5 in both branches,
        # so the dead conditional has been removed.
        self.total_num_classes = 5

        # Default masking probabilities from the paper.
        self.mask_probs = mask_prob or {
            "ecg": 0.5, "ppg": 0.1, "abd": 0.7, "thx": 0.7
        }

        # 1. Signal Encoders: one CNN per modality (paper uses 6-8 ResBlocks
        # depending on the sampling rate; 4 blocks are used here).
        self.feature_encoders = nn.ModuleDict()
        for key in feature_keys:
            layers = [ResBlock(1, 16)]
            layers += [ResBlock(16 * (2 ** i), 16 * (2 ** (i + 1))) for i in range(3)]
            layers.append(nn.AdaptiveAvgPool1d(1))
            self.feature_encoders[key] = nn.Sequential(*layers)

        # Bug fix: the encoders always emit _ENCODER_OUT_CHANNELS (=128)
        # features, but the original forward() hard-coded 128 and broke for
        # any other embedding_dim. Project into the embedding space when the
        # sizes differ; identity keeps the default configuration unchanged.
        self.encoder_proj = (
            nn.Linear(self._ENCODER_OUT_CHANNELS, embedding_dim)
            if embedding_dim != self._ENCODER_OUT_CHANNELS
            else nn.Identity()
        )

        # 2. Epoch Mixer: transformer encoder with a learned [CLS] token that
        # fuses the per-modality features of each epoch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=nhead, dim_feedforward=512,
            batch_first=True, activation="gelu"
        )
        self.epoch_mixer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. Sequence Mixer: dilated convolutions over the epoch sequence.
        # NOTE(review): the paper describes two blocks with dilations
        # (1, 2, 4, 8, 16, 32); this implementation uses a reduced stack with
        # dilations (2, 4). Padding is chosen so the sequence length is kept.
        self.sequence_mixer = nn.Sequential(
            nn.Conv1d(embedding_dim, embedding_dim, 7, padding=6, dilation=2),
            nn.GELU(),
            nn.Conv1d(embedding_dim, embedding_dim, 7, padding=12, dilation=4),
            nn.GELU(),
        )
        self.fc = nn.Linear(embedding_dim, self.total_num_classes)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass with stochastic masking and multi-modal fusion.

        Expects each feature key to map to a tensor of shape
        [batch, seq_len, signal_len] and ``label_key`` to map to integer
        class labels of shape [batch, seq_len].

        Returns:
            Dict with "y_prob" [batch, seq_len, num_classes], "y_true",
            and the scalar cross-entropy "loss".
        """
        batch_size = kwargs[self.feature_keys[0]].shape[0]
        seq_len = kwargs[self.feature_keys[0]].shape[1]  # number of sleep epochs

        # Per-modality epoch embeddings, each [B, T, 1, D].
        all_modality_features = []

        for key in self.feature_keys:
            # reshape (not view) tolerates non-contiguous inputs.
            x = kwargs[key].reshape(-1, 1, kwargs[key].shape[-1])  # [B*T, 1, L]
            feat = self.feature_encoders[key](x).view(batch_size, seq_len, -1)
            feat = self.encoder_proj(feat)  # [B, T, D]

            # Stochastic masking during training: drop an entire modality per
            # sample with probability p so the model learns to cope with
            # missing channels at inference time.
            if self.training:
                p = self.mask_probs.get(key.lower(), 0.5)
                mask = (torch.rand(batch_size, 1, 1, device=feat.device) > p).float()
                feat = feat * mask

            all_modality_features.append(feat.unsqueeze(2))  # [B, T, 1, D]

        # Combine modalities for the epoch mixer: [B*T, num_modalities, D].
        # Bug fix: use embedding_dim instead of the hard-coded 128.
        x = torch.cat(all_modality_features, dim=2).view(
            -1, len(self.feature_keys), self.embedding_dim
        )

        # Prepend the learned [CLS] token to each epoch's modality set.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [B*T, M+1, D]

        # Epoch fusion; the [CLS] position summarizes each epoch.
        x = self.epoch_mixer(x)
        z_t = x[:, 0, :].view(batch_size, seq_len, -1)  # [B, T, D]

        # Sequence mixing: capture temporal dependencies between epochs.
        z_t = z_t.transpose(1, 2)  # [B, D, T]
        z_seq = self.sequence_mixer(z_t).transpose(1, 2)  # [B, T, D]

        logits = self.fc(z_seq)
        labels = kwargs[self.label_key]

        # PyHealth expectation: return loss and probabilities.
        return {
            "y_prob": torch.softmax(logits, dim=-1),
            "y_true": labels,
            "loss": nn.functional.cross_entropy(
                logits.reshape(-1, self.total_num_classes),
                labels.reshape(-1),
            ),
        }
66 changes: 66 additions & 0 deletions tests/core/test_wav2sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Unit tests for Wav2Sleep model.
Requirement: Fast, performant, and uses synthetic data.
"""
import unittest
import torch
from pyhealth.models import Wav2Sleep


class TestWav2Sleep(unittest.TestCase):
    """Fast smoke tests for Wav2Sleep using small synthetic tensors."""

    def setUp(self):
        class MockDataset:
            """Minimal stand-in exposing the schema attributes a model may read."""

            def __init__(self):
                self.input_schema = {
                    "ecg": {"type": float},
                    "ppg": {"type": float},
                }
                self.output_schema = {
                    "label": {"type": int},
                }

        self.dataset = MockDataset()
        self.feature_keys = ["ecg", "ppg"]
        self.label_key = "label"

        self.model = Wav2Sleep(
            dataset=self.dataset,
            feature_keys=self.feature_keys,
            label_key=self.label_key,
            mode="multiclass",
            embedding_dim=128,
            nhead=4,
            num_layers=1,
        )

        self.model.total_num_classes = 5

    def test_forward_pass(self):
        """Test if the forward pass works and returns correct shapes."""
        n_batch, n_epochs, n_samples = 2, 10, 100  # small, fast synthetic sizes

        batch = {
            "ecg": torch.randn(n_batch, n_epochs, n_samples),
            "ppg": torch.randn(n_batch, n_epochs, n_samples),
            "label": torch.randint(0, 5, (n_batch, n_epochs)),
        }

        result = self.model(**batch)

        # Required output keys.
        for key in ("loss", "y_prob"):
            self.assertIn(key, result)

        # Probabilities have shape [B, T, C].
        self.assertEqual(result["y_prob"].shape, (n_batch, n_epochs, 5))

        # Loss is a scalar tensor.
        self.assertEqual(result["loss"].dim(), 0)


if __name__ == "__main__":
    unittest.main()