46 commits
d1f97af
init generators commit
chufangao Jun 15, 2025
ee8c52c
base
Jul 16, 2025
00f10c2
Stab at implementation
Jul 16, 2025
b666f82
Misc. changes for testing
Jul 27, 2025
ec4f23d
Remove testing logs
Jul 27, 2025
4ce8e21
Clean up things a bit
Jul 27, 2025
b1584fd
Clean up hardcoded file path
Jul 27, 2025
d374603
Remove testing files from PR
Jul 27, 2025
4f456f9
Init model properly
Jul 27, 2025
56380f6
Update comments
Jul 27, 2025
5d4ede6
Add HALO generator with training and generation examples
jalengg Feb 4, 2026
d2b8da3
Remove non-HALO README changes
jalengg Feb 16, 2026
97050b8
Create HALO Colab notebook structure with headers
jalengg Feb 16, 2026
58ef738
Add setup and installation cells to HALO notebook
jalengg Feb 16, 2026
21f394f
Add configuration, data upload, training, generation, and results cel…
jalengg Feb 16, 2026
b1458fe
Add README documentation for HALO Colab notebook
jalengg Feb 16, 2026
702d65c
Fix installation cell to detect pip failures
jalengg Feb 16, 2026
261f819
Remove pandas<2 constraint for Python 3.12 compatibility
jalengg Feb 16, 2026
4acf2f2
Fix MIMIC-III file upload issues in Colab notebook
jalengg Feb 16, 2026
b8b4c96
Add missing __init__.py to halo_resources module
jalengg Feb 16, 2026
564cf0a
Add --no-cache-dir to pip install for latest code
jalengg Feb 16, 2026
8002123
Fix path concatenation bug in HALO_MIMIC3Dataset
jalengg Feb 16, 2026
2d4fbbd
Add MANIFEST.in to include YAML config files in package
jalengg Feb 16, 2026
6ce060d
Fix YAML config packaging: use package_data instead of MANIFEST.in
jalengg Feb 17, 2026
2418788
Add install timestamp to Colab notebook success message
jalengg Feb 17, 2026
f1ceb35
Add last-updated timestamp to Colab notebook header
jalengg Feb 17, 2026
663dbb8
Use human-readable timestamp format in notebook header
jalengg Feb 17, 2026
bc0f41b
Fix pkl file path concatenation bug in HALO_MIMIC3Dataset
jalengg Feb 17, 2026
200e693
added trailing slash
shiitavie Feb 18, 2026
76de88e
added trailing slash
shiitavie Feb 18, 2026
5ec2c42
format string error
shiitavie Feb 18, 2026
bb5de81
remove assertion (issue #23)
shiitavie Feb 19, 2026
4040422
fix: complete merge - add missing processor files from upstream/master
jalengg Feb 23, 2026
0c9e973
chore: merge upstream/master, resolve processor/model conflicts
jalengg Feb 23, 2026
a864781
feat: add halo_generation task function (HaloGenerationMIMIC3/4)
jalengg Feb 23, 2026
fe08005
refactor: make HALO inherit BaseModel with forward() and train_model()
jalengg Feb 25, 2026
6a814d1
test: add synthesize_dataset coverage for HALO
jalengg Feb 25, 2026
661ec69
fix: collate_fn for variable visit lengths and end-token position in …
jalengg Feb 25, 2026
5fb8ce0
feat: update halo training example to PyHealth 2.0 API
jalengg Feb 25, 2026
8892dda
feat: update halo generation example to PyHealth 2.0 API
jalengg Feb 25, 2026
7df241f
feat: remove HALO_MIMIC3Dataset (replaced by HaloGenerationMIMIC3 task)
jalengg Feb 25, 2026
299d272
docs: update HALO docstrings to Google/PyHealth style
jalengg Feb 25, 2026
4dcff7c
docs: fix synthesize_dataset Returns style and dataset type annotation
jalengg Feb 25, 2026
6142476
test: add HALO end-to-end integration tests
jalengg Feb 25, 2026
1b744eb
test: fix tearDown cleanup, env var path, and relative bootstrap paths
jalengg Feb 25, 2026
d5248de
test: guard integration test against sys.modules stub contamination f…
jalengg Feb 25, 2026
8 changes: 7 additions & 1 deletion .gitignore
@@ -133,7 +133,13 @@ leaderboard/rtd_token.txt

# locally pre-trained models
pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model

# local testing files
halo_testing/
halo_testing_script.py
test_halo_model.slurm

data/physionet.org/

# VSCode settings
.vscode/
55 changes: 55 additions & 0 deletions examples/README.md
@@ -0,0 +1,55 @@
# PyHealth Examples

This directory contains example scripts and notebooks for using PyHealth.

## HALO Synthetic Data Generation

### Google Colab Notebook (No Cluster Required)

**File**: `halo_mimic3_colab.ipynb`

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sunlabuiuc/PyHealth/blob/master/examples/halo_mimic3_colab.ipynb)

Train HALO and generate synthetic MIMIC-III data directly in your browser using Google Colab.

**Requirements**:
- Google account (for Colab)
- MIMIC-III access from PhysioNet
- Files: ADMISSIONS.csv, DIAGNOSES_ICD.csv, PATIENTS.csv, patient_ids.txt

**Quick Start**:
1. Open `halo_mimic3_colab.ipynb` in Google Colab
2. Enable GPU (Runtime → Change runtime type → GPU)
3. Run cells in order
4. Upload your MIMIC-III files when prompted
5. Download synthetic data CSV
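After step 2, you can confirm in any notebook cell that the GPU runtime actually took effect before starting training. This is a minimal sketch assuming PyTorch is preinstalled in the Colab runtime (it normally is); the result depends on the runtime type you selected:

```python
import torch

# True when a CUDA device is visible to the runtime; if this prints
# False, re-check Runtime -> Change runtime type -> GPU and restart.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```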

**Demo vs Production**:
- **Demo** (default): 5 epochs, 1K samples, ~30 min
- **Production**: 80 epochs, 10K samples, ~6-10 hours (change configuration)

**Features**:
- Google Drive integration for persistence
- Resume capability if session times out
- Automatic checkpoint saving
- CSV output format
- Data quality validation

### Cluster Training (SLURM)

**Files**:
- `slurm/train_halo_mimic3.slurm` - Training script
- `slurm/generate_halo_mimic3.slurm` - Generation script
- `halo_mimic3_training.py` - Python training code
- `generate_synthetic_mimic3_halo.py` - Python generation code

For users with access to GPU clusters. See individual script headers for usage.

**Example**:
```bash
# Train
sbatch slurm/train_halo_mimic3.slurm

# Generate
sbatch slurm/generate_halo_mimic3.slurm
```
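The generation path described above writes a JSON file of records, each with a `patient_id` and a `visits` list of code lists. A minimal sketch of loading and summarizing that output — the two records here are illustrative stand-ins, not real MIMIC-III data:

```python
import json

# Illustrative records in the format written by the generation example:
# each patient carries a "patient_id" and "visits" (a list of code lists).
synthetic_data = [
    {"patient_id": "synthetic_0", "visits": [["401.9"], ["250.00", "428.0"]]},
    {"patient_id": "synthetic_1", "visits": [["038.9", "995.91"]]},
]

# Round-trip through JSON exactly as the example script does.
patients = json.loads(json.dumps(synthetic_data, indent=2))

total_visits = sum(len(p["visits"]) for p in patients)
avg_visits = total_visits / len(patients)
print(f"{len(patients)} patients, {total_visits} visits, "
      f"{avg_visits:.2f} visits/patient")
# → 2 patients, 3 visits, 1.50 visits/patient
```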
150 changes: 150 additions & 0 deletions examples/generate_synthetic_mimic3_halo.py
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Example: Generate synthetic MIMIC-III patients using a trained HALO checkpoint.

Loads MIMIC3Dataset with the halo_generation_mimic3_fn task (identical to
training) so that the vocabulary is reconstructed, then loads the saved
checkpoint and calls model.synthesize_dataset(). Output is saved as JSON.

Usage:
    python examples/generate_synthetic_mimic3_halo.py
    python examples/generate_synthetic_mimic3_halo.py --save_dir ./my_save/ --num_samples 500
"""

import argparse
import json
import os

import torch

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models.generators.halo import HALO
from pyhealth.tasks import halo_generation_mimic3_fn


def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate synthetic MIMIC-III patients with HALO"
    )
    parser.add_argument(
        "--mimic3_root",
        default="/path/to/mimic3",
        help="Root directory of MIMIC-III data (default: /path/to/mimic3)",
    )
    parser.add_argument(
        "--save_dir",
        default="./save/",
        help="Directory containing the trained halo_model checkpoint (default: ./save/)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000,
        help="Number of synthetic patients to generate (default: 1000)",
    )
    parser.add_argument(
        "--output",
        default="synthetic_patients.json",
        help="Output JSON file path (default: synthetic_patients.json)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # ------------------------------------------------------------------
    # STEP 1: Load MIMIC-III dataset
    # The dataset must use the same tables and code_mapping as training
    # so that the vocabulary is identical.
    # ------------------------------------------------------------------
    print("Loading MIMIC-III dataset...")
    base_dataset = MIMIC3Dataset(
        root=args.mimic3_root,
        tables=["diagnoses_icd"],  # If you trained with different tables=, update this to match.
        code_mapping={},
        dev=False,
        refresh_cache=False,
    )
    print(f"  Loaded {len(base_dataset.patients)} patients")

    # ------------------------------------------------------------------
    # STEP 2: Apply the HALO generation task
    # set_task builds the vocabulary via NestedSequenceProcessor — must
    # match the task used during training exactly.
    # ------------------------------------------------------------------
    print("Applying HALO generation task...")
    sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn)
    print(f"  {len(sample_dataset)} samples after task filtering")

    # ------------------------------------------------------------------
    # STEP 3: Instantiate HALO with the same hyperparameters as training
    # The model constructor uses the dataset to determine vocab sizes;
    # the weights are loaded from the checkpoint immediately after.
    # ------------------------------------------------------------------
    print("Initializing HALO model...")
    model = HALO(
        dataset=sample_dataset,
        embed_dim=768,
        n_heads=12,
        n_layers=12,
        n_ctx=48,
        batch_size=48,
        epochs=50,  # unused during generation; must match training for checkpoint compatibility
        pos_loss_weight=None,
        lr=1e-4,
        save_dir=args.save_dir,
    )

    # ------------------------------------------------------------------
    # STEP 4: Load trained checkpoint
    # The training loop saves to save_dir/halo_model with keys
    # "model" (halo_model state dict) and "optimizer".
    # ------------------------------------------------------------------
    checkpoint_path = os.path.join(args.save_dir, "halo_model")
    print(f"Loading checkpoint from {checkpoint_path} ...")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at {checkpoint_path}. "
            "Train the model first with examples/halo_mimic3_training.py."
        )
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.halo_model.load_state_dict(checkpoint["model"])
    print("  Checkpoint loaded successfully")

    # ------------------------------------------------------------------
    # STEP 5: Generate synthetic patients
    # synthesize_dataset returns List[Dict] where each dict has:
    #   "patient_id": "synthetic_N"
    #   "visits": [[code, ...], ...]
    # ------------------------------------------------------------------
    print(f"Generating {args.num_samples} synthetic patients...")
    synthetic_data = model.synthesize_dataset(
        num_samples=args.num_samples,
        random_sampling=True,
    )

    # ------------------------------------------------------------------
    # STEP 6: Save output as JSON
    # ------------------------------------------------------------------
    print(f"Saving output to {args.output} ...")
    with open(args.output, "w") as f:
        json.dump(synthetic_data, f, indent=2)

    # ------------------------------------------------------------------
    # STEP 7: Print summary statistics
    # ------------------------------------------------------------------
    total_patients = len(synthetic_data)
    total_visits = sum(len(p["visits"]) for p in synthetic_data)
    avg_visits = total_visits / total_patients if total_patients > 0 else 0.0

    print("\n--- Generation Summary ---")
    print(f"  Patients generated : {total_patients}")
    print(f"  Total visits       : {total_visits}")
    print(f"  Avg visits/patient : {avg_visits:.2f}")
    print(f"  Output saved to    : {args.output}")
    print("Done.")


if __name__ == "__main__":
    main()
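The checkpoint layout assumed in STEP 4 above is a `torch.save` dict with `"model"` and `"optimizer"` keys. A minimal sketch of writing and inspecting a checkpoint in that layout — the tiny `Linear` module here is a stand-in for the real HALO network, not its actual architecture:

```python
import os
import tempfile

import torch

# Stand-in for the HALO network: any nn.Module with a state dict works.
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Save in the same two-key layout the HALO training loop uses.
path = os.path.join(tempfile.mkdtemp(), "halo_model")
torch.save({"model": net.state_dict(), "optimizer": optimizer.state_dict()}, path)

# Restore on CPU, exactly as the generation example does.
checkpoint = torch.load(path, map_location="cpu")
net.load_state_dict(checkpoint["model"])
print(sorted(checkpoint.keys()))  # → ['model', 'optimizer']
```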
353 changes: 353 additions & 0 deletions examples/halo_mimic3_colab.ipynb

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions examples/halo_mimic3_training.py
@@ -0,0 +1,66 @@
"""
Example: Training HALO on MIMIC-III for synthetic EHR generation.

This script demonstrates how to train the HALO model using PyHealth's
standard dataset and task patterns. HALO learns to generate synthetic
patient visit sequences via autoregressive transformer training.

Usage:
    python examples/halo_mimic3_training.py

Replace the ``root`` path below with the local path to your MIMIC-III
data directory before running.
"""

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.models.generators.halo import HALO
from pyhealth.tasks import halo_generation_mimic3_fn

# Step 1: Load MIMIC-III dataset
print("Loading MIMIC-III dataset...")
base_dataset = MIMIC3Dataset(
    root="/path/to/mimic3",
    tables=["diagnoses_icd"],
)
base_dataset.stats()

# Step 2: Set task for HALO generation
# halo_generation_mimic3_fn extracts diagnosis code sequences per patient.
# Each patient produces one sample with all their visits (admissions with
# at least one ICD-9 code). Patients with fewer than 2 qualifying visits
# are excluded.
print("Setting HALO generation task...")
sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn)
print(f"Samples after task: {len(sample_dataset)}")

# Step 3: Split dataset by patient (no patient appears in more than one split)
print("Splitting dataset...")
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Step 4: Initialize HALO model
# The model derives vocabulary size automatically from the dataset's
# NestedSequenceProcessor. No manual vocabulary setup is needed.
print("Initializing HALO model...")
model = HALO(
    dataset=sample_dataset,
    embed_dim=768,
    n_heads=12,
    n_layers=12,
    n_ctx=48,
    batch_size=48,
    epochs=50,
    pos_loss_weight=None,
    lr=1e-4,
    save_dir="./save/",
)

# Step 5: Train using HALO's custom training loop
# HALO does not use the PyHealth Trainer; it has its own loop that
# validates after every epoch and saves the best checkpoint to save_dir.
print("Starting training...")
model.train_model(train_dataset, val_dataset)

print("Training complete. Best checkpoint saved to ./save/halo_model")
36 changes: 36 additions & 0 deletions examples/slurm/generate_halo_mimic3.slurm
@@ -0,0 +1,36 @@
#!/bin/bash
#SBATCH --job-name=halo_generate
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --time=2:00:00
#SBATCH --output=/scratch/%u/logs/halo_generate_%j.out

# Canonical SLURM script for generating synthetic data with HALO
# Adjust paths, partition names, and resource allocations for your cluster

# Navigate to working directory
cd "${SLURM_SUBMIT_DIR}" || exit 1

echo "SLURM_JOB_ID: ${SLURM_JOB_ID}"
echo "Starting HALO generation at: $(date)"
echo "========================================"

# Activate your Python environment
# Example: conda activate pyhealth
# Example: source venv/bin/activate

# Set Python path if needed
# export PYTHONPATH=/path/to/PyHealth:${PYTHONPATH}

# Generation script. Flags match the argparse interface defined in
# examples/generate_synthetic_mimic3_halo.py (--mimic3_root, --save_dir,
# --num_samples, --output); adjust paths for your cluster.
python examples/generate_synthetic_mimic3_halo.py \
    --mimic3_root /path/to/mimic3 \
    --save_dir /scratch/jalenj4/halo_results/ \
    --num_samples 10000 \
    --output /scratch/jalenj4/halo_results/synthetic/halo_synthetic_10k.json

echo "========================================"
echo "Generation completed at: $(date)"
38 changes: 38 additions & 0 deletions examples/slurm/train_halo_mimic3.slurm
@@ -0,0 +1,38 @@
#!/bin/bash
#SBATCH --job-name=halo_train
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=12:00:00
#SBATCH --output=/scratch/%u/logs/halo_train_%j.out

# Canonical SLURM script for training HALO on MIMIC-III
# Adjust paths, partition names, and resource allocations for your cluster

# Navigate to working directory
cd "${SLURM_SUBMIT_DIR}" || exit 1

echo "SLURM_JOB_ID: ${SLURM_JOB_ID}"
echo "Starting HALO training at: $(date)"
echo "========================================"

# Activate your Python environment
# Example: conda activate pyhealth
# Example: source venv/bin/activate

# Set Python path if needed
# export PYTHONPATH=/path/to/PyHealth:${PYTHONPATH}

# Training script. examples/halo_mimic3_training.py takes no command-line
# arguments; set the MIMIC-III root path, save_dir, and hyperparameters
# inside that file before submitting this job.
python examples/halo_mimic3_training.py

echo "========================================"
echo "Training completed at: $(date)"