class MultimodalBottleneckTransformerEncoder(nn.Module):
    """Generalized bottleneck transformer encoder for N modalities.

    Based on "Attention Bottlenecks for Multimodal Fusion"
    (Nagrani et al., NeurIPS 2021). Each modality is first processed by
    independent "prefusion" transformer layers; cross-modal information is
    then exchanged only through a small set of shared bottleneck tokens
    during the remaining "fusion" layers.

    Args:
        n_modality: number of input modalities.
        bottlenecks_n: number of shared bottleneck tokens.
        fusion_startidx: layer index at which bottleneck fusion starts
            (the first ``fusion_startidx`` layers are prefusion layers).
        n_layers: total number of transformer layers.
        n_head: number of attention heads per transformer layer.
        d_model: hidden dimension of every token.
        d_ff: feed-forward dimension inside each transformer layer.
        dropout: dropout rate applied inside each transformer layer.

    Raises:
        ValueError: if ``fusion_startidx`` is not in ``[0, n_layers]``.
    """

    def __init__(
        self,
        n_modality: int,
        bottlenecks_n: int,
        fusion_startidx: int,
        n_layers: int,
        n_head: int,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Fail fast on an inconsistent layer split instead of silently
        # building an empty or over-sized prefusion stack.
        if not 0 <= fusion_startidx <= n_layers:
            raise ValueError(
                f"fusion_startidx must be in [0, {n_layers}], "
                f"got {fusion_startidx}"
            )

        self.n_modality = n_modality
        self.fusion_startidx = fusion_startidx
        self.n_layers = n_layers
        self.n_fusion_layers = n_layers - fusion_startidx
        self.n_prefusion = fusion_startidx
        self.d_model = d_model
        self.n_bottlenecks = bottlenecks_n

        # Shared bottleneck tokens, expanded per batch in forward().
        self.bottlenecks = nn.Parameter(torch.randn(1, bottlenecks_n, d_model))

        def make_stack(depth: int) -> nn.ModuleList:
            # One independent TransformerEncoderLayer per modality, per depth.
            return nn.ModuleList([
                nn.ModuleList([
                    nn.TransformerEncoderLayer(
                        d_model=d_model,
                        nhead=n_head,
                        dim_feedforward=d_ff,
                        dropout=dropout,
                        batch_first=True,
                    )
                    for _ in range(n_modality)
                ])
                for _ in range(depth)
            ])

        # Prefusion stacks: per-modality layers, no cross-modal interaction.
        self.prefusion_stacks = make_stack(self.n_prefusion)
        # Fusion stacks: each layer sees [bottleneck_tokens || modality_tokens].
        self.fusion_stacks = make_stack(self.n_fusion_layers)

    def forward_prefusion(
        self, enc_inputs: List[torch.Tensor], masks: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Run each modality through its independent prefusion layers.

        ``masks`` entries use True for *valid* tokens (or None when every
        token is valid); ``src_key_padding_mask`` expects True for
        *padding*, hence the inversion below.
        """
        for enc_layers in self.prefusion_stacks:
            next_inputs = []
            for modal_idx, enc_layer in enumerate(enc_layers):
                modality_mask = masks[modal_idx]
                padding_mask = None if modality_mask is None else ~modality_mask
                next_inputs.append(
                    enc_layer(
                        enc_inputs[modal_idx],
                        src_key_padding_mask=padding_mask,
                    )
                )
            enc_inputs = next_inputs
        return enc_inputs

    def forward_fusion(
        self,
        enc_inputs: List[torch.Tensor],
        masks: List[torch.Tensor],
        bottleneck_tokens: torch.Tensor,
        valid_modalities: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """Fuse modalities through the shared bottleneck tokens.

        Args:
            enc_inputs: per-modality token tensors, each [B, L_i, d_model].
            masks: per-modality validity masks [B, L_i] (True = valid),
                or None when every token is valid.
            bottleneck_tokens: [B, n_bottlenecks, d_model] shared tokens.
            valid_modalities: per-modality [B] float tensors; 1.0 when the
                modality is present for that instance, 0.0 when missing.
        """
        batch_size = enc_inputs[0].size(0)

        for modality_encoders in self.fusion_stacks:
            next_inputs = []
            bottleneck_sum = torch.zeros_like(bottleneck_tokens)
            n_valid = torch.zeros(
                batch_size, 1, 1, device=bottleneck_tokens.device
            )

            for idx, enc_layer in enumerate(modality_encoders):
                # [B, n_bottlenecks + L_idx, d_model]
                fused_input = torch.cat(
                    [bottleneck_tokens, enc_inputs[idx]], dim=1
                )

                # Bottleneck positions are always attended (False = keep).
                b_mask = torch.zeros(
                    batch_size,
                    self.n_bottlenecks,
                    dtype=torch.bool,
                    device=fused_input.device,
                )
                if masks[idx] is not None:
                    m_mask = ~masks[idx]
                else:
                    m_mask = torch.zeros(
                        batch_size,
                        enc_inputs[idx].size(1),
                        dtype=torch.bool,
                        device=fused_input.device,
                    )
                enc_out = enc_layer(
                    fused_input,
                    src_key_padding_mask=torch.cat([b_mask, m_mask], dim=1),
                )

                # Split back into updated bottleneck and modality tokens.
                bottleneck_hidden = enc_out[:, : self.n_bottlenecks, :]
                next_inputs.append(enc_out[:, self.n_bottlenecks :, :])

                # Only present modalities contribute to the bottleneck update.
                is_valid = valid_modalities[idx].view(batch_size, 1, 1)
                bottleneck_sum += bottleneck_hidden * is_valid
                n_valid += is_valid

            # Average over present modalities; clamp avoids division by zero
            # when all modalities are missing (those rows get zero tokens).
            bottleneck_tokens = bottleneck_sum / n_valid.clamp(min=1.0)
            enc_inputs = next_inputs

        return enc_inputs

    def forward(
        self, enc_inputs: List[torch.Tensor], masks: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Encode all modalities and return per-modality token tensors."""
        batch_size = enc_inputs[0].size(0)

        # A modality is present for an instance iff its mask has at least
        # one valid token; a None mask means always present.
        valid_modalities = []
        for mask, inp in zip(masks, enc_inputs):
            if mask is None:
                valid_modalities.append(
                    torch.ones(batch_size, device=inp.device)
                )
            else:
                valid_modalities.append(mask.any(dim=1).float())

        bottleneck_tokens = self.bottlenecks.expand(batch_size, -1, -1)

        enc_inputs = self.forward_prefusion(enc_inputs, masks)
        return self.forward_fusion(
            enc_inputs, masks, bottleneck_tokens, valid_modalities
        )
+ + Examples: + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["A", "B", "C"], + ... "procedures": ["X", "Y"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-0", + ... "conditions": ["D"], + ... "procedures": ["Z", "Y"], + ... "label": 0, + ... }, + ... ] + >>> input_schema = {"conditions": "sequence", "procedures": "sequence"} + >>> output_schema = {"label": "binary"} + >>> dataset = create_sample_dataset( + ... samples, + ... input_schema, + ... output_schema, + ... dataset_name="demo", + ... ) + >>> model = BottleneckTransformer(dataset=dataset, num_layers=3, fusion_startidx=1, bottlenecks_n=4) + >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> batch = next(iter(loader)) + >>> output = model(**batch) + >>> sorted(output.keys()) + ['logit', 'loss', 'y_prob', 'y_true'] + """ + + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + bottlenecks_n: int = 4, + fusion_startidx: int = 1, + num_layers: int = 3, + heads: int = 4, + dropout: float = 0.5, + ): + super().__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.bottlenecks_n = bottlenecks_n + self.fusion_startidx = fusion_startidx + self.num_layers = num_layers + self.heads = heads + self.dropout = dropout + + assert ( + len(self.label_keys) == 1 + ), "Only one label key is supported if BottleneckTransformer is initialized" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + + self.n_modality = len(self.feature_keys) + + # Classification tokens for each modality + self.cls_token_per_modality = nn.ParameterList([ + nn.Parameter(torch.randn(1, 1, embedding_dim)) for _ in range(self.n_modality) + ]) + + self.encoder = MultimodalBottleneckTransformerEncoder( + 
n_modality=self.n_modality, + bottlenecks_n=bottlenecks_n, + fusion_startidx=fusion_startidx, + n_layers=num_layers, + n_head=heads, + d_model=embedding_dim, + d_ff=embedding_dim * 4, + dropout=dropout + ) + + output_size = self.get_output_size() + # Outputs of each modality's CLS token are averaged, not concatenated + self.fc = nn.Linear(embedding_dim, output_size) + + @staticmethod + def _pool_embedding(x: torch.Tensor) -> torch.Tensor: + if x.dim() == 4: + x = x.sum(dim=2) + if x.dim() == 2: + x = x.unsqueeze(1) + return x + + @staticmethod + def _mask_from_embeddings(x: torch.Tensor) -> torch.Tensor: + mask = torch.any(torch.abs(x) > 0, dim=-1) + if mask.dim() == 1: + mask = mask.unsqueeze(1) + invalid_rows = ~mask.any(dim=1) + if invalid_rows.any(): + mask[invalid_rows, 0] = True + return mask.bool() + + def forward( + self, + **kwargs: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + ) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: keyword arguments for the model. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the final loss. + y_prob: a tensor of predicted probabilities. + y_true: a tensor representing the true labels. + logit: the raw logits before activation. + """ + enc_inputs = [] + masks = [] + + for idx, feature_key in enumerate(self.feature_keys): + feature = kwargs[feature_key] + + if isinstance(feature, torch.Tensor): + feature = (feature,) + + schema = self.dataset.input_processors[feature_key].schema() + + value = feature[schema.index("value")] if "value" in schema else None + mask = feature[schema.index("mask")] if "mask" in schema else None + + if len(feature) == len(schema) + 1 and mask is None: + mask = feature[-1] + + if value is None: + raise ValueError( + f"Feature '{feature_key}' must contain 'value' " + f"in the schema." 
+ ) + else: + value = value.to(self.device) + + if mask is not None: + mask = mask.to(self.device) + value = self.embedding_model({feature_key: value}, masks={feature_key: mask})[feature_key] + else: + value = self.embedding_model({feature_key: value})[feature_key] + + value = self._pool_embedding(value) + + if mask is not None: + mask = mask.bool() + if mask.dim() == value.dim(): + mask = mask.any(dim=-1) + else: + mask = self._mask_from_embeddings(value) + + # Prepend Modality CLS token + batch_size = value.size(0) + cls_token = self.cls_token_per_modality[idx].expand(batch_size, -1, -1) + value = torch.cat([cls_token, value], dim=1) + + # Update mask for CLS token (always valid) + cls_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=value.device) + mask = torch.cat([cls_mask, mask], dim=1) + + enc_inputs.append(value) + masks.append(mask) + + # Pass through Bottleneck Transformer Encoder + enc_outputs = self.encoder(enc_inputs, masks) + + # Extract CLS tokens + cls_tokens = [out[:, 0, :].unsqueeze(1) for out in enc_outputs] + cls_tokens = torch.cat(cls_tokens, dim=1) # [B, n_modality, embedding_dim] + + # Average CLS tokens across valid modalities + b_size = cls_tokens.size(0) + valid_modalities = [] + for mask in masks: + # We check if there's any valid token aside from the CLS token (index 0) + if mask.size(1) > 1: + valid = mask[:, 1:].any(dim=1).float() + else: + valid = mask[:, 0].float() # fallback + valid_modalities.append(valid.view(b_size, 1, 1)) + + valid_modality_tensor = torch.cat(valid_modalities, dim=1) # [B, n_modality, 1] + + # Apply valid mask + masked_cls = cls_tokens * valid_modality_tensor + sum_valid = valid_modality_tensor.sum(dim=1) # [B, 1] + + # Avoid division by zero + sum_valid[sum_valid == 0] = 1.0 + patient_emb = masked_cls.sum(dim=1) / sum_valid # [B, embedding_dim] + + logits = self.fc(patient_emb) + y_prob = self.prepare_y_prob(logits) + + results = { + "logit": logits, + "y_prob": y_prob, + } + + if self.label_key in 
kwargs: + y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device) + loss = self.get_loss_function()(logits, y_true) + results["loss"] = loss + results["y_true"] = y_true + + return results + +if __name__ == "__main__": + from pyhealth.datasets import create_sample_dataset, get_dataloader + + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["A", "B", "C"], + "procedures": ["X", "Y"], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-0", + "conditions": ["D"], + "procedures": ["Z", "Y"], + "label": 0, + }, + ] + + input_schema = { + "conditions": "sequence", + "procedures": "sequence", + } + output_schema = {"label": "binary"} + + dataset = create_sample_dataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test", + ) + + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + + model = BottleneckTransformer( + dataset=dataset, + embedding_dim=64, + bottlenecks_n=2, + fusion_startidx=1, + num_layers=3, + heads=2 + ) + + data_batch = next(iter(train_loader)) + + result = model(**data_batch) + print(result) + + result["loss"].backward() + print("Test completed successfully.") From 89b32224dedbd29c34b6cfc4bc603c49f7e53077 Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 23 Feb 2026 01:40:11 -0500 Subject: [PATCH 2/2] Add tests, docs, and examples for Bottleneck Transformer --- docs/api/models.rst | 1 + .../pyhealth.models.BottleneckTransformer.rst | 68 ++++++ .../bottleneck_transformer_tutorial.ipynb | 199 ++++++++++++++++++ tests/core/test_bottleneck_transformer.py | 108 ++++++++++ 4 files changed, 376 insertions(+) create mode 100644 docs/api/models/pyhealth.models.BottleneckTransformer.rst create mode 100644 examples/bottleneck_transformer_tutorial.ipynb create mode 100644 tests/core/test_bottleneck_transformer.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 2621b6a2a..2da263ff3 100644 --- a/docs/api/models.rst +++ 
b/docs/api/models.rst @@ -14,6 +14,7 @@ We implement the following models for supporting multiple healthcare predictive models/pyhealth.models.RNN models/pyhealth.models.GNN models/pyhealth.models.Transformer + models/pyhealth.models.BottleneckTransformer models/pyhealth.models.TransformersModel models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet diff --git a/docs/api/models/pyhealth.models.BottleneckTransformer.rst b/docs/api/models/pyhealth.models.BottleneckTransformer.rst new file mode 100644 index 000000000..3b3824973 --- /dev/null +++ b/docs/api/models/pyhealth.models.BottleneckTransformer.rst @@ -0,0 +1,68 @@ +BottleneckTransformer +--------------------- + +.. autoclass:: pyhealth.models.BottleneckTransformer + :members: + :undoc-members: + :show-inheritance: + +**Overview** + +``BottleneckTransformer`` is a multimodal architecture based on *Attention Bottlenecks for Multimodal Fusion* (NeurIPS 2021). It uses shared bottleneck tokens to fuse representations across an arbitrary number of modalities—restricting cross-modal attention to improve compute efficiency and prevent noise from dominating specific modalities. + +**Input / Output** + +- **Input:** ``dict[str, Tensor]`` + — kwargs dictionary mapping from PyHealth's input_schema fields to tensor data (sequences, vectors, or multimodal tokens). +- **Output:** ``dict`` with keys: + + - ``"y_prob"`` — ``(B, num_classes)`` predicted probabilities + - ``"y_true"`` — ``(B, num_classes)`` true labels (if provided) + - ``"logit"`` — ``(B, num_classes)`` raw logits + - ``"loss"`` — scalar tensor (if true labels are provided) + +**Key Features** + +.. 
list-table::
   :header-rows: 1
   :widths: 20 80

   * - Feature
     - Description
   * - **Bottleneck Fusion**
     - Modalities interact solely through shared bottleneck tokens
   * - **Dynamic Modality Support**
     - Automatically adapts its encoder branches depending on ``dataset.input_schema`` length
   * - **Pre-fusion Encoding**
     - Intra-modal feature processing occurs independently up to the ``fusion_startidx`` layer
   * - **Token Masking Support**
     - Handles ragged (variable-length) sequence inputs via per-token ``mask`` tensors

**Example Usage**

.. code-block:: python

    from pyhealth.datasets import create_sample_dataset
    from pyhealth.models import BottleneckTransformer

    # Build multimodal dataset
    dataset = create_sample_dataset(
        samples=samples,
        input_schema={"conditions": "sequence", "procedures": "sequence"},
        output_schema={"label": "binary"}
    )

    # Initialize model
    model = BottleneckTransformer(
        dataset=dataset,
        embedding_dim=128,
        bottlenecks_n=4,   # Number of bottleneck tokens
        fusion_startidx=2, # Layer index to begin cross-modal fusion
        num_layers=4,
        heads=4
    )

..
autoclass:: pyhealth.models.bottleneck_transformer.MultimodalBottleneckTransformerEncoder + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/bottleneck_transformer_tutorial.ipynb b/examples/bottleneck_transformer_tutorial.ipynb new file mode 100644 index 000000000..2885be943 --- /dev/null +++ b/examples/bottleneck_transformer_tutorial.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro_md", + "metadata": {}, + "source": [ + "# Bottleneck Transformer Tutorial\n", + "\n", + "This notebook demonstrates how to use the `BottleneckTransformer` model for multimodal healthcare data fusion in PyHealth.\n", + "\n", + "**Overview:**\n", + "- Initialize BottleneckTransformer with multi-modality data\n", + "- Demonstrate modality-specific pre-fusion vs multimodal bottleneck fusion\n", + "- Highlight architecture hyperparameters `bottlenecks_n` and `fusion_startidx`\n", + "- Inspect forward passes and probability mappings" + ] + }, + { + "cell_type": "markdown", + "id": "env_md", + "metadata": {}, + "source": [ + "## 1. Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "env_code", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "data_md", + "metadata": {}, + "source": [ + "## 2. Data Preparation\n", + "We use PyHealth's `create_sample_dataset` to generate a lightweight multimodal dataset. You can substitute this with `MIMIC3Dataset`, `MIMIC4Dataset` or `OMOPDataset` for real-world scenarios." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "data_code", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import create_sample_dataset\n", + "\n", + "samples = [\n", + " {\n", + " \"patient_id\": \"patient-0\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"conditions\": [\"A\", \"B\", \"C\"],\n", + " \"procedures\": [\"X\", \"Y\"],\n", + " \"labs\": [1.0, 2.0, 3.0],\n", + " \"label\": 1,\n", + " },\n", + " {\n", + " \"patient_id\": \"patient-1\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"conditions\": [\"D\", \"E\"],\n", + " \"procedures\": [\"Y\"],\n", + " \"labs\": [4.0, 5.0, 6.0],\n", + " \"label\": 0,\n", + " },\n", + "]\n", + "\n", + "input_schema = {\n", + " \"conditions\": \"sequence\",\n", + " \"procedures\": \"sequence\",\n", + " \"labs\": \"tensor\",\n", + "}\n", + "output_schema = {\"label\": \"binary\"}\n", + "\n", + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema=input_schema,\n", + " output_schema=output_schema,\n", + " dataset_name=\"test\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "loader_md", + "metadata": {}, + "source": [ + "## 3. Dataloader Setup\n", + "We use PyHealth's automatic `get_dataloader` utility which converts the structured processed fields into batches." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "loader_code", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import get_dataloader\n", + "\n", + "train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "id": "model_md", + "metadata": {}, + "source": [ + "## 4. Initialize Bottleneck Transformer\n", + "The model initializes modality-specific transformer paths and limits the dense attention flow to bottleneck tokens specifically. \n", + "\n", + "- `fusion_startidx` parameter decides which layer cross-attention over bottlenecks activates. 
Lower means earlier fusion.\n", + "- `bottlenecks_n` regulates how many tokens represent the capacity of the bottleneck." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "model_code", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.models import BottleneckTransformer\n", + "\n", + "model = BottleneckTransformer(\n", + " dataset=dataset,\n", + " embedding_dim=128,\n", + " bottlenecks_n=4,\n", + " fusion_startidx=1,\n", + " num_layers=3,\n", + " heads=4\n", + ").to(device)\n", + "\n", + "print(\"Model modalities:\", model.feature_keys)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "id": "forward_md", + "metadata": {}, + "source": [ + "## 5. Forward Pass\n", + "Perform a simple mapping to inspect outputs. PyHealth models produce unified dicts returning `loss`, probability spaces `y_prob`, and predictions `logit`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "forward_code", + "metadata": {}, + "outputs": [], + "source": [ + "data_batch = next(iter(train_loader))\n", + "outputs = model(**data_batch)\n", + "\n", + "for k, v in outputs.items():\n", + " try:\n", + " print(f\"{k}: {v.shape}\")\n", + " except AttributeError:\n", + " print(f\"{k}: {v}\")\n", + "\n", + "print(\"\\nForward pass successful!\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/tests/core/test_bottleneck_transformer.py b/tests/core/test_bottleneck_transformer.py new file mode 100644 index 000000000..9adbe6432 --- /dev/null +++ b/tests/core/test_bottleneck_transformer.py @@ -0,0 +1,108 @@ +import 
class TestBottleneckTransformer(unittest.TestCase):
    """Unit tests for the BottleneckTransformer model."""

    def setUp(self):
        """Build a small three-modality sample dataset and a model."""
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "diagnoses": ["A", "B", "C"],
                "procedures": ["X", "Y"],
                "labs": [1.0, 2.0, 3.0],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-0",
                "diagnoses": ["D", "E"],
                "procedures": ["Y"],
                "labs": [4.0, 5.0, 6.0],
                "label": 0,
            },
        ]

        self.input_schema: Dict[str, Union[str, Type[FeatureProcessor]]] = {
            "diagnoses": "sequence",
            "procedures": "sequence",
            "labs": "tensor",
        }
        self.output_schema: Dict[str, Union[str, Type[FeatureProcessor]]] = {
            "label": "binary"
        }

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = BottleneckTransformer(
            dataset=self.dataset,
            embedding_dim=128,
            bottlenecks_n=2,
            fusion_startidx=1,
            num_layers=2,
            heads=2,
        )

    def test_model_initialization(self):
        """The model exposes the configured hyper-parameters and features."""
        self.assertIsInstance(self.model, BottleneckTransformer)
        self.assertEqual(self.model.embedding_dim, 128)
        self.assertEqual(self.model.bottlenecks_n, 2)
        self.assertEqual(self.model.num_layers, 2)
        self.assertEqual(len(self.model.feature_keys), 3)
        for feature_key in ("diagnoses", "procedures", "labs"):
            self.assertIn(feature_key, self.model.feature_keys)
        self.assertEqual(self.model.label_key, "label")

    def test_model_forward(self):
        """A forward pass returns loss, probabilities, labels, and logits."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))

        with torch.no_grad():
            output = self.model(**batch)

        for key in ("loss", "y_prob", "y_true", "logit"):
            self.assertIn(key, output)
        for key in ("y_prob", "y_true", "logit"):
            self.assertEqual(output[key].shape[0], 2)
            self.assertEqual(output[key].shape[1], 1)
        self.assertEqual(output["loss"].dim(), 0)

    def test_model_backward(self):
        """Backpropagation populates gradients on model parameters."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))

        output = self.model(**batch)
        output["loss"].backward()

        has_gradient = any(
            param.requires_grad and param.grad is not None
            for param in self.model.parameters()
        )
        self.assertTrue(has_gradient, "No parameters have gradients after backward pass")