Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ We implement the following models for supporting multiple healthcare predictive
models/pyhealth.models.RNN
models/pyhealth.models.GNN
models/pyhealth.models.Transformer
models/pyhealth.models.BottleneckTransformer
models/pyhealth.models.TransformersModel
models/pyhealth.models.RETAIN
models/pyhealth.models.GAMENet
Expand Down
68 changes: 68 additions & 0 deletions docs/api/models/pyhealth.models.BottleneckTransformer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
BottleneckTransformer
---------------------

.. autoclass:: pyhealth.models.BottleneckTransformer
:members:
:undoc-members:
:show-inheritance:

**Overview**

``BottleneckTransformer`` is a multimodal architecture based on *Attention Bottlenecks for Multimodal Fusion* (NeurIPS 2021). It uses shared bottleneck tokens to fuse representations across an arbitrary number of modalities—restricting cross-modal attention to improve compute efficiency and prevent noise from dominating specific modalities.

**Input / Output**

- **Input:** ``dict[str, Tensor]``
— kwargs dictionary mapping from PyHealth's input_schema fields to tensor data (sequences, vectors, or multimodal tokens).
- **Output:** ``dict`` with keys:

- ``"y_prob"`` — ``(B, num_classes)`` predicted probabilities
- ``"y_true"`` — ``(B, num_classes)`` true labels (if provided)
- ``"logit"`` — ``(B, num_classes)`` raw logits
- ``"loss"`` — scalar tensor (if true labels are provided)

**Key Features**

.. list-table::
:header-rows: 1
:widths: 20 80

* - Feature
- Description
* - **Bottleneck Fusion**
- Modalities interact solely through shared bottleneck tokens
* - **Dynamic Modality Support**
- Automatically adapts its encoder branches depending on ``dataset.input_schema`` length
* - **Pre-fusion Encoding**
- Intra-modal feature processing occurs independently up to ``fusion_startidx`` layer
* - **Token Masking Support**
- Safely processes ragged inputs dynamically generated by sequences with a ``mask``

**Example Usage**

.. code-block:: python

from pyhealth.datasets import create_sample_dataset
from pyhealth.models import BottleneckTransformer

# Build multimodal dataset
dataset = create_sample_dataset(
samples=samples,
input_schema={"conditions": "sequence", "procedures": "sequence"},
output_schema={"label": "binary"}
)

# Initialize model
model = BottleneckTransformer(
dataset=dataset,
embedding_dim=128,
bottlenecks_n=4, # Number of bottleneck tokens
fusion_startidx=2, # Layer index to begin cross-modal fusion
num_layers=4,
heads=4
)

.. autoclass:: pyhealth.models.bottleneck_transformer.MultimodalBottleneckTransformerEncoder
:members:
:undoc-members:
:show-inheritance:
199 changes: 199 additions & 0 deletions examples/bottleneck_transformer_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "intro_md",
"metadata": {},
"source": [
"# Bottleneck Transformer Tutorial\n",
"\n",
"This notebook demonstrates how to use the `BottleneckTransformer` model for multimodal healthcare data fusion in PyHealth.\n",
"\n",
"**Overview:**\n",
"- Initialize BottleneckTransformer with multi-modality data\n",
"- Demonstrate modality-specific pre-fusion vs multimodal bottleneck fusion\n",
"- Highlight architecture hyperparameters `bottlenecks_n` and `fusion_startidx`\n",
"- Inspect forward passes and probability mappings"
]
},
{
"cell_type": "markdown",
"id": "env_md",
"metadata": {},
"source": [
"## 1. Environment Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "env_code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Running on device: {device}\")"
]
},
{
"cell_type": "markdown",
"id": "data_md",
"metadata": {},
"source": [
"## 2. Data Preparation\n",
"We use PyHealth's `create_sample_dataset` to generate a lightweight multimodal dataset. You can substitute this with `MIMIC3Dataset`, `MIMIC4Dataset` or `OMOPDataset` for real-world scenarios."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "data_code",
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets import create_sample_dataset\n",
"\n",
"samples = [\n",
" {\n",
" \"patient_id\": \"patient-0\",\n",
" \"visit_id\": \"visit-0\",\n",
" \"conditions\": [\"A\", \"B\", \"C\"],\n",
" \"procedures\": [\"X\", \"Y\"],\n",
" \"labs\": [1.0, 2.0, 3.0],\n",
" \"label\": 1,\n",
" },\n",
" {\n",
" \"patient_id\": \"patient-1\",\n",
" \"visit_id\": \"visit-0\",\n",
" \"conditions\": [\"D\", \"E\"],\n",
" \"procedures\": [\"Y\"],\n",
" \"labs\": [4.0, 5.0, 6.0],\n",
" \"label\": 0,\n",
" },\n",
"]\n",
"\n",
"input_schema = {\n",
" \"conditions\": \"sequence\",\n",
" \"procedures\": \"sequence\",\n",
" \"labs\": \"tensor\",\n",
"}\n",
"output_schema = {\"label\": \"binary\"}\n",
"\n",
"dataset = create_sample_dataset(\n",
" samples=samples,\n",
" input_schema=input_schema,\n",
" output_schema=output_schema,\n",
" dataset_name=\"test\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "loader_md",
"metadata": {},
"source": [
"## 3. Dataloader Setup\n",
"We use PyHealth's automatic `get_dataloader` utility which converts the structured processed fields into batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "loader_code",
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets import get_dataloader\n",
"\n",
"train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)"
]
},
{
"cell_type": "markdown",
"id": "model_md",
"metadata": {},
"source": [
"## 4. Initialize Bottleneck Transformer\n",
"The model initializes modality-specific transformer paths and limits the dense attention flow to bottleneck tokens specifically. \n",
"\n",
"- `fusion_startidx` parameter decides which layer cross-attention over bottlenecks activates. Lower means earlier fusion.\n",
"- `bottlenecks_n` regulates how many tokens represent the capacity of the bottleneck."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "model_code",
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.models import BottleneckTransformer\n",
"\n",
"model = BottleneckTransformer(\n",
" dataset=dataset,\n",
" embedding_dim=128,\n",
" bottlenecks_n=4,\n",
" fusion_startidx=1,\n",
" num_layers=3,\n",
" heads=4\n",
").to(device)\n",
"\n",
"print(\"Model modalities:\", model.feature_keys)\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "forward_md",
"metadata": {},
"source": [
"## 5. Forward Pass\n",
"Perform a simple mapping to inspect outputs. PyHealth models produce unified dicts returning `loss`, probability spaces `y_prob`, and predictions `logit`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "forward_code",
"metadata": {},
"outputs": [],
"source": [
"data_batch = next(iter(train_loader))\n",
"outputs = model(**data_batch)\n",
"\n",
"for k, v in outputs.items():\n",
" try:\n",
" print(f\"{k}: {v.shape}\")\n",
" except AttributeError:\n",
" print(f\"{k}: {v}\")\n",
"\n",
"print(\"\\nForward pass successful!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 4 additions & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from .adacare import AdaCare, AdaCareLayer
from .agent import Agent, AgentLayer
from .base_model import BaseModel
from .bottleneck_transformer import (
BottleneckTransformer,
MultimodalBottleneckTransformerEncoder
)
from .biot import BIOT
from .cnn import CNN, CNNLayer
from .concare import ConCare, ConCareLayer
Expand Down
Loading