Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ Available Datasets
datasets/pyhealth.datasets.BMDHSDataset
datasets/pyhealth.datasets.COVID19CXRDataset
datasets/pyhealth.datasets.ChestXray14Dataset
datasets/pyhealth.datasets.PhysioNetDeIDDataset
datasets/pyhealth.datasets.TUABDataset
datasets/pyhealth.datasets.TUEVDataset
datasets/pyhealth.datasets.ClinVarDataset
Expand Down
9 changes: 9 additions & 0 deletions docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pyhealth.datasets.PhysioNetDeIDDataset
=======================================

The PhysioNet De-Identification dataset. For more information see `here <https://physionet.org/content/deidentifiedmedicaltext/1.0/>`_. Access requires PhysioNet credentialing.

.. autoclass:: pyhealth.datasets.PhysioNetDeIDDataset
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ API Reference
models/pyhealth.models.GNN
models/pyhealth.models.Transformer
models/pyhealth.models.TransformersModel
models/pyhealth.models.TransformerDeID
models/pyhealth.models.RETAIN
models/pyhealth.models.GAMENet
models/pyhealth.models.GraphCare
Expand Down
9 changes: 9 additions & 0 deletions docs/api/models/pyhealth.models.TransformerDeID.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pyhealth.models.TransformerDeID
===================================

Transformer-based token classifier for clinical text de-identification.

.. autoclass:: pyhealth.models.TransformerDeID
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ Available Tasks
Sleep Staging v2 <tasks/pyhealth.tasks.sleep_staging_v2>
Benchmark EHRShot <tasks/pyhealth.tasks.benchmark_ehrshot>
ChestX-ray14 Binary Classification <tasks/pyhealth.tasks.ChestXray14BinaryClassification>
De-Identification NER <tasks/pyhealth.tasks.DeIDNERTask>
ChestX-ray14 Multilabel Classification <tasks/pyhealth.tasks.ChestXray14MultilabelClassification>
Variant Classification (ClinVar) <tasks/pyhealth.tasks.VariantClassificationClinVar>
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.DeIDNERTask
=======================================

.. autoclass:: pyhealth.tasks.DeIDNERTask
:members:
:undoc-members:
:show-inheritance:
191 changes: 191 additions & 0 deletions examples/physionet_deid_ner_transformer_deid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Train and evaluate TransformerDeID on PhysioNet de-identification.

End-to-end example: load data, train BERT-base on token-level NER for
PHI detection, and report binary (PHI vs non-PHI) precision/recall/F1.

Paper: Johnson et al. "Deidentification of free-text medical records
using pre-trained bidirectional transformers." CHIL, 2020.

Script structure follows examples/cardiology_detection_isAR_SparcNet.py.

Hyperparameters follow the paper (Section 3.4):
- Learning rate: 5e-5
- Batch size: 8
- Epochs: 3
- Weight decay: 0.01

Ablation results (3 epochs, 80/10/10 patient split, seed=42):

Config Precision Recall F1
BERT, no window 95.1% 70.3% 80.8%
BERT, win=64/32 94.1% 69.0% 79.6%
BERT, win=100/60 86.9% 75.7% 80.9%
BERT, win=200/100 94.7% 69.4% 80.1%
RoBERTa, no window 98.1% 64.7% 78.0%
RoBERTa, win=100/60 82.6% 68.6% 75.0%

BERT with window=100/60 achieves the best F1 (80.9%), matching the
paper's window configuration. Windowing improves recall by allowing
BERT to see tokens beyond the 512 truncation limit. RoBERTa has
higher precision but lower recall than BERT across configurations.

Usage:
python examples/physionet_deid_ner_transformer_deid.py \
--data_root path/to/deidentifiedmedicaltext/1.0

# With windowing (paper Section 3.3):
python examples/physionet_deid_ner_transformer_deid.py \
--data_root path/to/data --window_size 100 --window_overlap 60

# With RoBERTa:
python examples/physionet_deid_ner_transformer_deid.py \
--data_root path/to/data --model_name roberta-base

Author:
Matt McKenna (mtm16@illinois.edu)
"""

import argparse
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

from pyhealth.datasets import PhysioNetDeIDDataset, get_dataloader
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.models.transformer_deid import (
IGNORE_INDEX,
TransformerDeID,
)
from pyhealth.tasks import DeIDNERTask
from pyhealth.trainer import Trainer


def compute_metrics(model, dataloader):
"""Binary PHI vs non-PHI token-level metrics with window merging.

When windowing is used, multiple windows may cover the same token.
We merge by taking the non-O prediction with highest probability
(paper Section 3.3). Without windowing, each token appears once
so no merging is needed.
"""
# Collect per-token gold labels and prediction probabilities,
# keyed by (patient_id, note_id, absolute_token_position).
token_gold = {}
token_preds = defaultdict(list)

model.eval()
with torch.no_grad():
for batch in dataloader:
result = model(**batch)
probs = result["y_prob"] # (batch, seq_len, num_labels)
labels = result["y_true"] # (batch, seq_len)
patient_ids = batch["patient_id"]
note_ids = batch["note_id"]
token_starts = batch["token_start"]

for i in range(len(patient_ids)):
pid = patient_ids[i]
nid = note_ids[i]
start = int(token_starts[i])
word_idx = 0
for j in range(labels.shape[1]):
if labels[i, j].item() == IGNORE_INDEX:
continue
key = (pid, nid, start + word_idx)
token_gold[key] = labels[i, j].item()
token_preds[key].append(probs[i, j].cpu().numpy())
word_idx += 1

# Merge overlapping predictions (paper Section 3.3):
# if any window predicts non-O, take the non-O with highest score.
all_true, all_pred = [], []
for key in sorted(token_gold):
all_true.append(token_gold[key])
preds = token_preds[key]
non_o = [(p, p[1:].max()) for p in preds if np.argmax(p) != 0]
if non_o:
merged = max(non_o, key=lambda x: x[1])[0]
else:
merged = np.mean(preds, axis=0)
all_pred.append(int(np.argmax(merged)))

# Convert to binary: O (index 0) = 0, any PHI = 1.
true_bin = [0 if t == 0 else 1 for t in all_true]
pred_bin = [0 if p == 0 else 1 for p in all_pred]
return {
"precision": precision_score(true_bin, pred_bin),
"recall": recall_score(true_bin, pred_bin),
"f1": f1_score(true_bin, pred_bin),
}


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
required=True,
help="Path to deidentifiedmedicaltext/1.0 directory",
)
parser.add_argument("--model_name", type=str, default="bert-base-uncased")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--window_size", type=int, default=None,
help="Token window size (default: no windowing)")
parser.add_argument("--window_overlap", type=int, default=0)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

# 1. Load dataset and set task.
print("Loading dataset...")
dataset = PhysioNetDeIDDataset(root=args.data_root)
task = DeIDNERTask(
window_size=args.window_size,
window_overlap=args.window_overlap,
)
samples = dataset.set_task(task)
print(f" Patients: {len(dataset.unique_patient_ids)}, Samples: {len(samples)}")

# 2. Split by patient (80/10/10) so no patient's notes appear in
# both train and test.
train_data, val_data, test_data = split_by_patient(
samples, [0.8, 0.1, 0.1], seed=args.seed
)
train_loader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=args.batch_size, shuffle=False)

# 3. Create model.
model = TransformerDeID(
dataset=samples,
model_name=args.model_name,
)

# 4. Train using PyHealth's Trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(model=model, device=device)
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=args.epochs,
optimizer_class=torch.optim.AdamW,
optimizer_params={"lr": args.lr},
weight_decay=0.01,
monitor="loss",
monitor_criterion="min",
)

# 5. Evaluate on test set.
print("\n=== Test Set Results (binary PHI vs non-PHI) ===")
metrics = compute_metrics(model, test_loader)
for k, v in metrics.items():
print(f" {k}: {v:.4f}")

samples.close()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs):
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .physionet_deid import PhysioNetDeIDDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
Expand Down
10 changes: 10 additions & 0 deletions pyhealth/datasets/configs/physionet_deid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: "1.0"
tables:
physionet_deid:
file_path: "physionet_deid_metadata.csv"
patient_id: "patient_id"
timestamp: null
attributes:
- "note_id"
- "text"
- "labels"
Loading