diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..8d9a59d21 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.PhysioNetDeIDDataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst new file mode 100644 index 000000000..4e04cd629 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.PhysioNetDeIDDataset +======================================= + +The PhysioNet De-Identification dataset. For more information see `here <https://physionet.org/content/deidentifiedmedicaltext/1.0/>`_. Access requires PhysioNet credentialing. + +.. autoclass:: pyhealth.datasets.PhysioNetDeIDDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..166402e86 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -177,6 +177,7 @@ API Reference models/pyhealth.models.GNN models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel + models/pyhealth.models.TransformerDeID models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet models/pyhealth.models.GraphCare diff --git a/docs/api/models/pyhealth.models.TransformerDeID.rst b/docs/api/models/pyhealth.models.TransformerDeID.rst new file mode 100644 index 000000000..d07aa94aa --- /dev/null +++ b/docs/api/models/pyhealth.models.TransformerDeID.rst @@ -0,0 +1,9 @@ +pyhealth.models.TransformerDeID +=================================== + +Transformer-based token classifier for clinical text de-identification. + +.. 
autoclass:: pyhealth.models.TransformerDeID + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..23a4e06e5 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -224,6 +224,7 @@ Available Tasks Sleep Staging v2 Benchmark EHRShot ChestX-ray14 Binary Classification + De-Identification NER ChestX-ray14 Multilabel Classification Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) diff --git a/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst new file mode 100644 index 000000000..2b7428f6e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DeIDNERTask +======================================= + +.. autoclass:: pyhealth.tasks.DeIDNERTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py new file mode 100644 index 000000000..eaa421bd9 --- /dev/null +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -0,0 +1,191 @@ +"""Train and evaluate TransformerDeID on PhysioNet de-identification. + +End-to-end example: load data, train BERT-base on token-level NER for +PHI detection, and report binary (PHI vs non-PHI) precision/recall/F1. + +Paper: Johnson et al. "Deidentification of free-text medical records + using pre-trained bidirectional transformers." CHIL, 2020. + +Script structure follows examples/cardiology_detection_isAR_SparcNet.py. 
+ +Hyperparameters follow the paper (Section 3.4): + - Learning rate: 5e-5 + - Batch size: 8 + - Epochs: 3 + - Weight decay: 0.01 + +Ablation results (3 epochs, 80/10/10 patient split, seed=42): + + Config Precision Recall F1 + BERT, no window 95.1% 70.3% 80.8% + BERT, win=64/32 94.1% 69.0% 79.6% + BERT, win=100/60 86.9% 75.7% 80.9% + BERT, win=200/100 94.7% 69.4% 80.1% + RoBERTa, no window 98.1% 64.7% 78.0% + RoBERTa, win=100/60 82.6% 68.6% 75.0% + + BERT with window=100/60 achieves the best F1 (80.9%), matching the + paper's window configuration. Windowing improves recall by allowing + BERT to see tokens beyond the 512 truncation limit. RoBERTa has + higher precision but lower recall than BERT across configurations. + +Usage: + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/deidentifiedmedicaltext/1.0 + + # With windowing (paper Section 3.3): + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/data --window_size 100 --window_overlap 60 + + # With RoBERTa: + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/data --model_name roberta-base + +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +import argparse +from collections import defaultdict + +import numpy as np +import torch +from sklearn.metrics import precision_score, recall_score, f1_score + +from pyhealth.datasets import PhysioNetDeIDDataset, get_dataloader +from pyhealth.datasets.splitter import split_by_patient +from pyhealth.models.transformer_deid import ( + IGNORE_INDEX, + TransformerDeID, +) +from pyhealth.tasks import DeIDNERTask +from pyhealth.trainer import Trainer + + +def compute_metrics(model, dataloader): + """Binary PHI vs non-PHI token-level metrics with window merging. + + When windowing is used, multiple windows may cover the same token. + We merge by taking the non-O prediction with highest probability + (paper Section 3.3). Without windowing, each token appears once + so no merging is needed. 
+ """ + # Collect per-token gold labels and prediction probabilities, + # keyed by (patient_id, note_id, absolute_token_position). + token_gold = {} + token_preds = defaultdict(list) + + model.eval() + with torch.no_grad(): + for batch in dataloader: + result = model(**batch) + probs = result["y_prob"] # (batch, seq_len, num_labels) + labels = result["y_true"] # (batch, seq_len) + patient_ids = batch["patient_id"] + note_ids = batch["note_id"] + token_starts = batch["token_start"] + + for i in range(len(patient_ids)): + pid = patient_ids[i] + nid = note_ids[i] + start = int(token_starts[i]) + word_idx = 0 + for j in range(labels.shape[1]): + if labels[i, j].item() == IGNORE_INDEX: + continue + key = (pid, nid, start + word_idx) + token_gold[key] = labels[i, j].item() + token_preds[key].append(probs[i, j].cpu().numpy()) + word_idx += 1 + + # Merge overlapping predictions (paper Section 3.3): + # if any window predicts non-O, take the non-O with highest score. + all_true, all_pred = [], [] + for key in sorted(token_gold): + all_true.append(token_gold[key]) + preds = token_preds[key] + non_o = [(p, p[1:].max()) for p in preds if np.argmax(p) != 0] + if non_o: + merged = max(non_o, key=lambda x: x[1])[0] + else: + merged = np.mean(preds, axis=0) + all_pred.append(int(np.argmax(merged))) + + # Convert to binary: O (index 0) = 0, any PHI = 1. 
+ true_bin = [0 if t == 0 else 1 for t in all_true] + pred_bin = [0 if p == 0 else 1 for p in all_pred] + return { + "precision": precision_score(true_bin, pred_bin), + "recall": recall_score(true_bin, pred_bin), + "f1": f1_score(true_bin, pred_bin), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + required=True, + help="Path to deidentifiedmedicaltext/1.0 directory", + ) + parser.add_argument("--model_name", type=str, default="bert-base-uncased") + parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--window_size", type=int, default=None, + help="Token window size (default: no windowing)") + parser.add_argument("--window_overlap", type=int, default=0) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + # 1. Load dataset and set task. + print("Loading dataset...") + dataset = PhysioNetDeIDDataset(root=args.data_root) + task = DeIDNERTask( + window_size=args.window_size, + window_overlap=args.window_overlap, + ) + samples = dataset.set_task(task) + print(f" Patients: {len(dataset.unique_patient_ids)}, Samples: {len(samples)}") + + # 2. Split by patient (80/10/10) so no patient's notes appear in + # both train and test. + train_data, val_data, test_data = split_by_patient( + samples, [0.8, 0.1, 0.1], seed=args.seed + ) + train_loader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader(test_data, batch_size=args.batch_size, shuffle=False) + + # 3. Create model. + model = TransformerDeID( + dataset=samples, + model_name=args.model_name, + ) + + # 4. Train using PyHealth's Trainer. 
+ device = "cuda" if torch.cuda.is_available() else "cpu" + trainer = Trainer(model=model, device=device) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.epochs, + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": args.lr}, + weight_decay=0.01, + monitor="loss", + monitor_criterion="min", + ) + + # 5. Evaluate on test set. + print("\n=== Test Set Results (binary PHI vs non-PHI) ===") + metrics = compute_metrics(model, test_loader) + for k, v in metrics.items(): + print(f" {k}: {v:.4f}") + + samples.close() + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..50b1b3887 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .physionet_deid import PhysioNetDeIDDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/pyhealth/datasets/configs/physionet_deid.yaml b/pyhealth/datasets/configs/physionet_deid.yaml new file mode 100644 index 000000000..2054ab809 --- /dev/null +++ b/pyhealth/datasets/configs/physionet_deid.yaml @@ -0,0 +1,10 @@ +version: "1.0" +tables: + physionet_deid: + file_path: "physionet_deid_metadata.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "note_id" + - "text" + - "labels" diff --git a/pyhealth/datasets/physionet_deid.py b/pyhealth/datasets/physionet_deid.py new file mode 100644 index 000000000..5bd6acc95 --- /dev/null +++ b/pyhealth/datasets/physionet_deid.py @@ -0,0 +1,352 @@ +""" +PyHealth dataset for the PhysioNet De-Identification dataset. 
+ +Dataset link: + https://physionet.org/content/deidentifiedmedicaltext/1.0/ + +Dataset paper: (please cite if you use this dataset) + Neamatullah, Ishna, et al. "Automated de-identification of free-text + medical records." BMC Medical Informatics and Decision Making 8.1 (2008). + +Paper link: + https://doi.org/10.1186/1472-6947-8-32 + +PHI category mapping in classify_phi() inspired by the label groupings +in the bert-deid reference implementation by Johnson et al.: + https://github.com/alistairewj/bert-deid/blob/master/bert_deid/label.py + +Task paper: + Johnson, Alistair E.W., et al. "Deidentification of free-text medical + records using pre-trained bidirectional transformers." Proceedings of + the ACM Conference on Health, Inference, and Learning (CHIL), 2020. + +Author: + Matt McKenna (mtm16@illinois.edu) +""" +import logging +import os +import re +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import pandas as pd + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + +# -- Record parsing regexes -- + +_RECORD_START = re.compile(r"START_OF_RECORD=(\d+)\|\|\|\|(\d+)\|\|\|\|") +_RECORD_END = re.compile(r"\|\|\|\|END_OF_RECORD") +_PHI_TAG = re.compile(r"\[\*\*(.+?)\*\*\]", re.DOTALL) +_PHI_SPLIT = re.compile(r"\[\*\*(?:.+?)\*\*\]", re.DOTALL) + + +def _parse_file(path: Path) -> Dict[Tuple[str, str], str]: + """Parse a PhysioNet record file into {(patient_id, note_id): body}. + + Args: + path: Path to id.text or id.res file. + + Returns: + Dictionary mapping (patient_id, note_id) to note body text. 
+ """ + raw = path.read_text(encoding="utf-8", errors="ignore") + out: Dict[Tuple[str, str], str] = {} + for m in _RECORD_START.finditer(raw): + pid, nid = m.group(1), m.group(2) + body_start = m.end() + end_m = _RECORD_END.search(raw, body_start) + body = raw[body_start : end_m.start() if end_m else len(raw)] + out[(pid, nid)] = body.strip() + return out + + +def classify_phi(raw: str) -> str: + """Map raw [**...**] tag text to one of the 7 PHI categories. + + Args: + raw: The text inside a [**...**] tag. + + Returns: + One of: AGE, DATE, CONTACT, LOCATION, ID, PROFESSION, NAME. + """ + t = re.sub(r"[^a-z0-9 ]+", " ", raw.strip().lower()).strip() + + if any(k in t for k in ("year old", " yo ", " age ")): + return "AGE" + if any(k in t for k in ("date", "month", "day", "year", "holiday")): + return "DATE" + if re.fullmatch(r"[\d]{1,2}[ \-/][\d]{1,2}([ \-/][\d]{2,4})?", t): + return "DATE" + if re.fullmatch(r"[\d]+", raw.strip()): + return "DATE" + if any(k in t for k in ("phone", "fax", "email", "pager", "contact")): + return "CONTACT" + if any( + k in t + for k in ( + "hospital", + "location", + "street", + "county", + "state", + "country", + "zip", + "address", + "ward", + "room", + ) + ): + return "LOCATION" + if any( + k in t + for k in ( + "mrn", + "medical record", + "record number", + "ssn", + "account", + "serial", + "unit no", + "unit number", + "identifier", + ) + ): + return "ID" + if " id " in f" {t} ": + return "ID" + if any( + k in t + for k in ( + "doctor", + " dr ", + " md ", + "nurse", + "attending", + "resident", + "profession", + "service", + "provider", + ) + ): + return "PROFESSION" + if any( + k in t + for k in ( + "name", + "initial", + "alias", + "patient", + "first name", + "last name", + ) + ): + return "NAME" + return "NAME" # fallback + + +def phi_spans_in_original( + orig: str, deid: str +) -> List[Tuple[int, int, str]]: + """Find PHI character spans in orig by anchoring on non-PHI chunks. 
+ + Uses non-PHI text from the de-identified version as anchors to locate + where the original PHI text appears in the original note. + + Args: + orig: Original note text (with real PHI). + deid: De-identified note text (PHI replaced with [**...**] tags). + + Returns: + List of (char_start, char_end, phi_category) tuples. + """ + parts = _PHI_SPLIT.split(deid) + tags = _PHI_TAG.findall(deid) + + spans: List[Tuple[int, int, str]] = [] + pos = 0 + + for i, tag_inner in enumerate(tags): + before = parts[i] + if before: + idx = orig.find(before, pos) + pos = (idx + len(before)) if idx != -1 else (pos + len(before)) + + phi_start = pos + + after = parts[i + 1] + if after: + idx = orig.find(after, phi_start) + phi_end = idx if idx != -1 else phi_start + else: + phi_end = len(orig) + + if phi_end > phi_start: + spans.append((phi_start, phi_end, classify_phi(tag_inner))) + + pos = phi_end + + return spans + + +def bio_tag( + text: str, spans: List[Tuple[int, int, str]] +) -> List[Tuple[str, str]]: + """Whitespace-tokenize text and assign BIO labels from char-level spans. + + Args: + text: Original note text. + spans: List of (char_start, char_end, phi_category) tuples. + + Returns: + List of (word, label) tuples. + """ + char_label = ["O"] * len(text) + for start, end, cat in spans: + for i in range(start, min(end, len(text))): + char_label[i] = cat + + result: List[Tuple[str, str]] = [] + for m in re.finditer(r"\S+", text): + w_start, w_end = m.start(), m.end() + word = m.group() + # Collect PHI categories for this token's characters, ignoring O's. + # If empty, the token has no PHI and we label it O. + cats = [c for c in char_label[w_start:w_end] if c != "O"] + if not cats: + result.append((word, "O")) + continue + # Pick the most common category. Handles rare cases where a token + # spans two PHI types, e.g. "Smith01/15" -> chars are NAME+DATE, + # majority wins (DATE). 
+ cat = max(set(cats), key=cats.count) + # B = beginning of a new entity, I = continuation of the same one. + # Use "I" only if the previous token was the same category, + # e.g. "Tom"=B-NAME "Garcia"=I-NAME. Otherwise start a new "B". + prev_label = result[-1][1] if result else "O" + prefix = ( + "I" + if prev_label not in ("O",) and prev_label.endswith(cat) + else "B" + ) + result.append((word, f"{prefix}-{cat}")) + + return result + + +class PhysioNetDeIDDataset(BaseDataset): + """Dataset class for the PhysioNet De-Identification dataset. + + This dataset contains 2,434 nursing notes from 163 patients. + Each note has original text with PHI (protected health information) + and a de-identified version with [**...**] tags marking PHI spans. + + The dataset parses both files to produce token-level BIO labels + for 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME, + PROFESSION. + + Data access requires PhysioNet credentialing: + 1. Create a PhysioNet account at https://physionet.org + 2. Complete the required CITI training + 3. Sign the data use agreement + 4. Download from + https://physionet.org/content/deidentifiedmedicaltext/1.0/ + + Attributes: + root (str): Root directory containing id.text and id.res files. + dataset_name (str): Name of the dataset. + + Example:: + >>> dataset = PhysioNetDeIDDataset(root="./data/physionet_deid") + """ + + def __init__( + self, + root: str = ".", + config_path: Optional[str] = str( + Path(__file__).parent / "configs" / "physionet_deid.yaml" + ), + **kwargs, + ) -> None: + """Initializes the PhysioNet De-Identification dataset. + + Args: + root: Root directory containing id.text and id.res files. + config_path: Path to the configuration file. + + Raises: + FileNotFoundError: If id.text or id.res not found in root. 
+ + Example:: + >>> dataset = PhysioNetDeIDDataset(root="./data") + """ + self._verify_data(root) + self._index_data(root) + + super().__init__( + root=root, + tables=["physionet_deid"], + dataset_name="PhysioNetDeID", + config_path=config_path, + **kwargs, + ) + + def _verify_data(self, root: str) -> None: + """Verify that required data files exist. + + Args: + root: Root directory to check. + + Raises: + FileNotFoundError: If id.text or id.res is missing. + """ + for fname in ("id.text", "id.res"): + path = os.path.join(root, fname) + if not os.path.isfile(path): + raise FileNotFoundError( + f"Required file '{fname}' not found in {root}" + ) + + def _index_data(self, root: str) -> pd.DataFrame: + """Parse id.text and id.res into a CSV for BaseDataset to load. + + Args: + root: Root directory containing the data files. + + Returns: + DataFrame with columns: patient_id, note_id, text, labels. + """ + root_path = Path(root) + orig_records = _parse_file(root_path / "id.text") + deid_records = _parse_file(root_path / "id.res") + + rows = [] + for key in sorted( + orig_records, key=lambda k: (int(k[0]), int(k[1])) + ): + pid, nid = key + orig = orig_records[key] + # Missing key yields empty string (no deid version). 
+ deid = deid_records.get(key, "") + spans = phi_spans_in_original(orig, deid) + tagged = bio_tag(orig, spans) + + tokens = " ".join(w for w, _ in tagged) + labels = " ".join(lbl for _, lbl in tagged) + + rows.append( + { + "patient_id": pid, + "note_id": nid, + "text": tokens, + "labels": labels, + } + ) + + df = pd.DataFrame(rows) + df.to_csv( + os.path.join(root, "physionet_deid_metadata.csv"), index=False + ) + return df diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..deabaf95c 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,6 +1,7 @@ from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare from .agent import Agent, AgentLayer from .base_model import BaseModel +from .transformer_deid import TransformerDeID from .biot import BIOT from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer diff --git a/pyhealth/models/transformer_deid.py b/pyhealth/models/transformer_deid.py new file mode 100644 index 000000000..964942f5a --- /dev/null +++ b/pyhealth/models/transformer_deid.py @@ -0,0 +1,263 @@ +""" +PyHealth model for transformer-based clinical text de-identification. + +Performs token-level NER to detect PHI (protected health information) +in clinical notes using a pre-trained transformer with a classification +head. + +Paper: Johnson, Alistair E.W., et al. "Deidentification of free-text + medical records using pre-trained bidirectional transformers." + Proceedings of the ACM Conference on Health, Inference, and + Learning (CHIL), 2020. + +Paper link: + https://doi.org/10.1145/3368555.3384455 + +Model structure (dropout + linear head) follows PyHealth's +TransformersModel (pyhealth/models/transformers_model.py), adapted +for token-level classification instead of sequence-level. + +Subword alignment follows the standard HuggingFace token +classification pattern (see BertForTokenClassification). 
+ +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +import logging +from typing import Dict, List + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer + +from ..datasets import SampleDataset +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +# 7 PHI categories with BIO prefix, plus O for non-PHI. +LABEL_VOCAB = { + "O": 0, + "B-AGE": 1, "I-AGE": 2, + "B-CONTACT": 3, "I-CONTACT": 4, + "B-DATE": 5, "I-DATE": 6, + "B-ID": 7, "I-ID": 8, + "B-LOCATION": 9, "I-LOCATION": 10, + "B-NAME": 11, "I-NAME": 12, + "B-PROFESSION": 13, "I-PROFESSION": 14, +} + +# Cross-entropy ignores positions with this index (PyTorch convention). +IGNORE_INDEX = -100 + + +def align_labels( + word_ids: List[int | None], + word_labels: List[int], +) -> List[int]: + """Align word-level labels to subword tokens. + + BERT/RoBERTa tokenizers split words into subwords. For example, + "Smith" might become ["Sm", "##ith"]. This function assigns the + word's label to the first subtoken and IGNORE_INDEX to the rest, + so the loss function skips non-first subtokens. Special tokens + ([CLS], [SEP], padding) have word_id=None and also get + IGNORE_INDEX. + + Args: + word_ids: Output of tokenizer.word_ids(). None for special + tokens, integer word index for real tokens. + word_labels: Label index for each word in the original text. + + Returns: + List of label indices, one per subtoken. Non-first subtokens + and special tokens are set to IGNORE_INDEX (-100). + """ + aligned = [] + prev_word_id = None + for word_id in word_ids: + if word_id is None: + # Special token ([CLS], [SEP], padding). + aligned.append(IGNORE_INDEX) + elif word_id != prev_word_id: + # First subtoken of a word: use the word's label. + aligned.append(word_labels[word_id]) + else: + # Non-first subtoken: ignore during loss computation. 
+ aligned.append(IGNORE_INDEX) + prev_word_id = word_id + return aligned + + +class TransformerDeID(BaseModel): + """Transformer-based token classifier for clinical text de-identification. + + Uses a pre-trained transformer encoder with a linear classification + head to predict BIO-tagged PHI labels for each token. + + Args: + dataset: A SampleDataset from set_task(). + model_name: HuggingFace model name. Default "bert-base-uncased". + max_length: Maximum token sequence length. Default 512. + dropout: Dropout rate for the classification head. Default 0.1. + + Examples: + >>> from pyhealth.datasets import PhysioNetDeIDDataset + >>> from pyhealth.tasks import DeIDNERTask + >>> from pyhealth.models import TransformerDeID + >>> dataset = PhysioNetDeIDDataset(root="/path/to/data") + >>> samples = dataset.set_task(DeIDNERTask()) + >>> model = TransformerDeID(dataset=samples) # BERT + >>> model = TransformerDeID(dataset=samples, model_name="roberta-base") + """ + + def __init__( + self, + dataset: SampleDataset, + model_name: str = "bert-base-uncased", + max_length: int = 512, + dropout: float = 0.1, + ): + super(TransformerDeID, self).__init__(dataset=dataset) + + assert len(self.feature_keys) == 1, ( + "TransformerDeID expects exactly one input feature (text)." + ) + assert len(self.label_keys) == 1, ( + "TransformerDeID expects exactly one label key." + ) + self.feature_key = self.feature_keys[0] + self.label_key = self.label_keys[0] + + self.model_name = model_name + self.max_length = max_length + self.label_vocab = LABEL_VOCAB + self.num_labels = len(LABEL_VOCAB) + + # add_prefix_space=True is required for RoBERTa when using + # is_split_into_words=True in the forward pass. 
+ self.tokenizer = AutoTokenizer.from_pretrained( + model_name, add_prefix_space=True + ) + self.encoder = AutoModel.from_pretrained( + model_name, + hidden_dropout_prob=dropout, + attention_probs_dropout_prob=dropout, + ) + hidden_size = self.encoder.config.hidden_size + self.dropout = nn.Dropout(dropout) + self.classifier = nn.Linear(hidden_size, self.num_labels) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Args: + **kwargs: Must contain self.feature_key (list of + space-joined token strings) and self.label_key + (list of space-joined BIO label strings). + + Returns: + Dict with keys: loss, logit, y_prob, y_true. + """ + texts: List[str] = kwargs[self.feature_key] + label_strings: List[str] = kwargs[self.label_key] + + # Tokenize with is_split_into_words=True so the tokenizer + # knows word boundaries and word_ids() works correctly. + words_batch = [t.split(" ") for t in texts] + encoding = self.tokenizer( + words_batch, + is_split_into_words=True, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + + # Convert word-level label strings to indices, then align + # to subword tokens. Positions that should be ignored during + # loss (special tokens, non-first subtokens, padding) get + # IGNORE_INDEX (-100), which cross-entropy skips. 
+ aligned_labels = [] + for i, label_str in enumerate(label_strings): + word_labels = [ + self.label_vocab[lbl] for lbl in label_str.split(" ") + ] + word_ids = encoding.word_ids(batch_index=i) + aligned_labels.append(align_labels(word_ids, word_labels)) + + labels = torch.tensor(aligned_labels, dtype=torch.long) + + # Move to device + input_ids = encoding["input_ids"].to(self.device) + attention_mask = encoding["attention_mask"].to(self.device) + labels = labels.to(self.device) + + # Encoder -> dropout -> classifier (per-token logits) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state + logits = self.classifier(self.dropout(hidden_states)) + + # Token-level cross-entropy, ignoring padded/special positions. + # We can't use BaseModel.get_loss_function() because it assumes + # one label per sample. Instead we call cross_entropy directly + # with ignore_index to skip special tokens and non-first subtokens. + # Flatten + ignore_index pattern from HuggingFace's + # BertForTokenClassification.forward(). + loss = nn.functional.cross_entropy( + logits.view(-1, self.num_labels), + labels.view(-1), + ignore_index=IGNORE_INDEX, + ) + + # Per-token probabilities via softmax. + y_prob = torch.softmax(logits, dim=-1) + + return { + "loss": loss, + "logit": logits, + "y_prob": y_prob, + "y_true": labels, + } + + def deidentify(self, text: str, redact: str = "[REDACTED]") -> str: + """Replace PHI in a clinical note with a redaction marker. + + Args: + text: Raw clinical note as a string. + redact: Replacement string for PHI tokens. + + Returns: + The note with PHI tokens replaced. + + Example:: + >>> model.deidentify("Patient John Smith was seen") + 'Patient [REDACTED] [REDACTED] was seen' + """ + words = text.split() + # Forward pass with dummy labels (all O) since we only + # need predictions, not loss. 
+ dummy_labels = " ".join(["O"] * len(words)) + self.eval() + with torch.no_grad(): + result = self(text=[text], labels=[dummy_labels]) + + preds = result["logit"][0].argmax(dim=-1) + y_true = result["y_true"][0] + + # Map predictions back to words using the non-ignored positions. + word_idx = 0 + output = [] + for j in range(len(preds)): + if y_true[j].item() == IGNORE_INDEX: + continue + if preds[j].item() != 0: # non-O = PHI + output.append(redact) + else: + output.append(words[word_idx]) + word_idx += 1 + + return " ".join(output) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..a32618f9c 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -12,6 +12,7 @@ from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification +from .deid_ner import DeIDNERTask from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 from .drug_recommendation import ( DrugRecommendationEICU, diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py new file mode 100644 index 000000000..67b212bac --- /dev/null +++ b/pyhealth/tasks/deid_ner.py @@ -0,0 +1,118 @@ +""" +PyHealth task for NER-based de-identification of clinical text. + +Converts PhysioNet De-Identification dataset records into token-level +BIO-tagged NER samples for PHI detection. + +Dataset link: + https://physionet.org/content/deidentifiedmedicaltext/1.0/ + +Task paper: (please cite if you use this task) + Johnson, Alistair E.W., et al. "Deidentification of free-text medical + records using pre-trained bidirectional transformers." Proceedings of + the ACM Conference on Health, Inference, and Learning (CHIL), 2020. 
+ +Paper link: + https://doi.org/10.1145/3368555.3384455 + +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +from typing import Dict, List, Optional, Type, Union + +from pyhealth.data import Event, Patient +from pyhealth.processors.text_processor import TextProcessor +from pyhealth.tasks import BaseTask + + +class DeIDNERTask(BaseTask): + """Token-level NER task for clinical text de-identification. + + Each sample contains a list of tokens and their BIO labels over + 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME, + PROFESSION. + + Supports optional overlapping windowing (paper Section 3.3) to + handle notes longer than BERT's 512 token limit. + + Args: + window_size: If set, split notes into overlapping windows of + this many tokens. Default None (no windowing). + window_overlap: Number of tokens shared between consecutive + windows. Default 0. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import PhysioNetDeIDDataset + >>> from pyhealth.tasks import DeIDNERTask + >>> dataset = PhysioNetDeIDDataset(root="/path/to/data") + >>> task = DeIDNERTask() + >>> samples = dataset.set_task(task) + >>> task_windowed = DeIDNERTask(window_size=100, window_overlap=60) + >>> samples = dataset.set_task(task_windowed) + """ + + task_name: str = "DeIDNER" + input_schema: Dict[str, str] = {"text": "text"} + # Labels are kept as a space-joined string. The model splits and + # encodes them into label indices. + output_schema: Dict[str, Union[str, Type]] = {"labels": TextProcessor} + + def __init__( + self, + window_size: Optional[int] = None, + window_overlap: int = 0, + ): + self.window_size = window_size + self.window_overlap = window_overlap + + def __call__(self, patient: Patient) -> List[Dict]: + """Generate NER samples from a patient's clinical notes. 
+ + Args: + patient: A Patient object with physionet_deid events. + + Returns: + List of dicts, each with 'text' (str) and + 'labels' (str) keys. Both are space-joined strings. + """ + events: List[Event] = patient.get_events( + event_type="physionet_deid" + ) + + samples = [] + for event in events: + note_id = event["note_id"] + words = event["text"].split(" ") + labels = event["labels"].split(" ") + + if self.window_size is None: + # No windowing: one sample per note. + samples.append({ + "patient_id": patient.patient_id, + "note_id": note_id, + "token_start": "0", + "text": event["text"], + "labels": event["labels"], + }) + else: + # Overlapping windows (paper Section 3.3). + step = self.window_size - self.window_overlap + idx = 0 + while idx < len(words): + end = min(idx + self.window_size, len(words)) + samples.append({ + "patient_id": patient.patient_id, + "note_id": note_id, + "token_start": str(idx), + "text": " ".join(words[idx:end]), + "labels": " ".join(labels[idx:end]), + }) + idx += step + + return samples diff --git a/test-resources/core/physionet_deid/id.res b/test-resources/core/physionet_deid/id.res new file mode 100644 index 000000000..1081b2717 --- /dev/null +++ b/test-resources/core/physionet_deid/id.res @@ -0,0 +1,12 @@ +START_OF_RECORD=10||||1|||| +Patient [**First Name 101**] [**Last Name 102**] was admitted on [**Date 103**] to [**Hospital 104**]. She is [**Age 105**] years old. Contact phone [**Phone 106**]. MRN [**Medical Record Number 107**]. +||||END_OF_RECORD +START_OF_RECORD=10||||2|||| +Seen by [**Doctor First Name 201**] [**Doctor Last Name 202**] in radiology. No acute findings. +||||END_OF_RECORD +START_OF_RECORD=20||||1|||| +Assessment unchanged. Vitals stable overnight. Continue current plan. +||||END_OF_RECORD +START_OF_RECORD=60||||1|||| +[**First Name 301**] [**Last Name 302**] presented to [**Hospital 303**] on [**Date 304**] for follow-up. He works as a [**Profession 305**]. Patient ID [**Medical Record Number 306**]. 
+||||END_OF_RECORD diff --git a/test-resources/core/physionet_deid/id.text b/test-resources/core/physionet_deid/id.text new file mode 100644 index 000000000..e61e9f09d --- /dev/null +++ b/test-resources/core/physionet_deid/id.text @@ -0,0 +1,12 @@ +START_OF_RECORD=10||||1|||| +Patient Jane Doe was admitted on 03/12/2098 to Springfield General Hospital. She is 72 years old. Contact phone 555-867-5309. MRN 00112233. +||||END_OF_RECORD +START_OF_RECORD=10||||2|||| +Seen by Dr. Robert Wells in radiology. No acute findings. +||||END_OF_RECORD +START_OF_RECORD=20||||1|||| +Assessment unchanged. Vitals stable overnight. Continue current plan. +||||END_OF_RECORD +START_OF_RECORD=60||||1|||| +Tom Garcia presented to Lakewood Clinic on 11/05/2097 for follow-up. He works as a plumber. Patient ID 99887766. +||||END_OF_RECORD diff --git a/tests/core/test_physionet_deid.py b/tests/core/test_physionet_deid.py new file mode 100644 index 000000000..6cfb47c6e --- /dev/null +++ b/tests/core/test_physionet_deid.py @@ -0,0 +1,187 @@ +""" +Unit tests for the PhysioNetDeIDDataset and DeIDNERTask classes. 
+ +Author: + Matt McKenna (mtm16@illinois.edu) +""" +from pathlib import Path +import tempfile +import unittest + +from pyhealth.datasets import PhysioNetDeIDDataset +from pyhealth.tasks import DeIDNERTask + + +class TestPhysioNetDeIDDataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "physionet_deid" + ) + cls.cache_dir = tempfile.TemporaryDirectory() + cls.dataset = PhysioNetDeIDDataset( + root=str(cls.root), cache_dir=cls.cache_dir.name + ) + cls.task = DeIDNERTask() + cls.samples = cls.dataset.set_task(cls.task) + + @classmethod + def tearDownClass(cls): + cls.samples.close() + cls.cache_dir.cleanup() + metadata = cls.root / "physionet_deid_metadata.csv" + if metadata.exists(): + metadata.unlink() + + def test_num_patients(self): + self.assertEqual(len(self.dataset.unique_patient_ids), 3) + + def test_patient_ids(self): + ids = set(self.dataset.unique_patient_ids) + self.assertEqual(ids, {"10", "20", "60"}) + + def test_patient_10_has_two_notes(self): + events = self.dataset.get_patient("10").get_events() + self.assertEqual(len(events), 2) + + def test_patient_20_has_one_note(self): + events = self.dataset.get_patient("20").get_events() + self.assertEqual(len(events), 1) + + def test_patient_60_has_one_note(self): + events = self.dataset.get_patient("60").get_events() + self.assertEqual(len(events), 1) + + def test_patient_10_note1_has_tokens_and_labels(self): + events = self.dataset.get_patient("10").get_events() + note1 = events[0] + self.assertIn("text", note1) + self.assertIn("labels", note1) + + def test_patient_10_note1_token_count(self): + """Note 1 for patient 10 should have the right number of tokens.""" + events = self.dataset.get_patient("10").get_events() + note1 = events[0] + tokens = note1["text"].split(" ") + self.assertEqual(len(tokens), 21) + + def test_patient_20_no_phi(self): + """Patient 20's note has no PHI, all labels should be O.""" + 
events = self.dataset.get_patient("20").get_events() + labels = events[0]["labels"].split(" ") + self.assertTrue(all(lbl == "O" for lbl in labels)) + + def test_patient_60_has_name_labels(self): + """Patient 60's note starts with NAME.""" + events = self.dataset.get_patient("60").get_events() + labels = events[0]["labels"].split(" ") + self.assertEqual(labels[0], "B-NAME") + self.assertEqual(labels[1], "I-NAME") + + def test_patient_60_has_location_label(self): + """Patient 60's note has LOCATION.""" + events = self.dataset.get_patient("60").get_events() + tokens = events[0]["text"].split(" ") + labels = events[0]["labels"].split(" ") + lakewood_idx = tokens.index("Lakewood") + self.assertEqual(labels[lakewood_idx], "B-LOCATION") + self.assertEqual(labels[lakewood_idx + 1], "I-LOCATION") + + def test_patient_60_has_date_label(self): + """Patient 60's note has DATE.""" + events = self.dataset.get_patient("60").get_events() + tokens = events[0]["text"].split(" ") + labels = events[0]["labels"].split(" ") + date_idx = tokens.index("11/05/2097") + self.assertEqual(labels[date_idx], "B-DATE") + + def test_patient_60_has_profession_label(self): + """Patient 60's note has PROFESSION.""" + events = self.dataset.get_patient("60").get_events() + tokens = events[0]["text"].split(" ") + labels = events[0]["labels"].split(" ") + prof_idx = tokens.index("plumber.") + self.assertEqual(labels[prof_idx], "B-PROFESSION") + + def test_stats(self): + self.dataset.stats() + + # -- Task tests -- + + def test_task_sample_count(self): + """4 notes total across 3 patients.""" + self.assertEqual(len(self.samples), 4) + + def test_task_sample_has_text_and_labels(self): + sample = self.samples[0] + self.assertIn("text", sample) + self.assertIn("labels", sample) + + def test_task_text_and_labels_same_length(self): + for sample in self.samples: + tokens = sample["text"].split(" ") + labels = sample["labels"].split(" ") + self.assertEqual(len(tokens), len(labels)) + + def 
test_task_labels_are_valid_bio(self): + valid = {"O"} + for cat in ("AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION"): + valid.add(f"B-{cat}") + valid.add(f"I-{cat}") + for sample in self.samples: + for label in sample["labels"].split(" "): + self.assertIn(label, valid) + + def test_task_sample_has_patient_id(self): + self.assertIn("patient_id", self.samples[0]) + + +class TestDeIDNERTaskWindowing(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "physionet_deid" + ) + cls.cache_dir = tempfile.TemporaryDirectory() + cls.dataset = PhysioNetDeIDDataset( + root=str(cls.root), cache_dir=cls.cache_dir.name + ) + cls.task = DeIDNERTask(window_size=10, window_overlap=5) + cls.samples = cls.dataset.set_task(cls.task) + + @classmethod + def tearDownClass(cls): + cls.samples.close() + cls.cache_dir.cleanup() + metadata = cls.root / "physionet_deid_metadata.csv" + if metadata.exists(): + metadata.unlink() + + def test_windowing_produces_more_samples(self): + """Windowing should produce more samples than the 4 notes.""" + self.assertGreater(len(self.samples), 4) + + def test_window_size_respected(self): + """Each window should have at most window_size tokens.""" + for sample in self.samples: + tokens = sample["text"].split(" ") + self.assertLessEqual(len(tokens), 10) + + def test_window_text_and_labels_same_length(self): + for sample in self.samples: + tokens = sample["text"].split(" ") + labels = sample["labels"].split(" ") + self.assertEqual(len(tokens), len(labels)) + + def test_window_has_patient_id(self): + self.assertIn("patient_id", self.samples[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_transformer_deid.py b/tests/core/test_transformer_deid.py new file mode 100644 index 000000000..e87a68c41 --- /dev/null +++ b/tests/core/test_transformer_deid.py @@ -0,0 +1,193 @@ +""" +Unit tests for the TransformerDeID model. 
+ +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset +from pyhealth.models.transformer_deid import ( + IGNORE_INDEX, + LABEL_VOCAB, + TransformerDeID, + align_labels, +) +from pyhealth.processors.text_processor import TextProcessor + + +def _make_dataset(): + """Create a minimal in-memory dataset matching DeIDNERTask output.""" + samples = [ + { + "patient_id": "p1", + "text": "Patient John Smith was seen", + "labels": "O B-NAME I-NAME O O", + }, + { + "patient_id": "p2", + "text": "Admitted on 01/15/2024 to clinic", + "labels": "O O B-DATE O O", + }, + ] + return create_sample_dataset( + samples=samples, + input_schema={"text": TextProcessor}, + output_schema={"labels": TextProcessor}, + dataset_name="test_deid", + task_name="DeIDNER", + in_memory=True, + ) + + +class TestLabelVocab(unittest.TestCase): + def test_vocab_size(self): + """O + 7 categories * 2 (B/I) = 15.""" + self.assertEqual(len(LABEL_VOCAB), 15) + + def test_o_is_zero(self): + self.assertEqual(LABEL_VOCAB["O"], 0) + + def test_all_categories_present(self): + for cat in ("AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION"): + self.assertIn(f"B-{cat}", LABEL_VOCAB) + self.assertIn(f"I-{cat}", LABEL_VOCAB) + + +class TestAlignLabels(unittest.TestCase): + def test_no_subword_splits(self): + """When every word is a single token, labels pass through.""" + # word_ids: None=CLS, 0, 1, 2, None=SEP + word_ids = [None, 0, 1, 2, None] + word_labels = [0, 11, 12] # O, B-NAME, I-NAME + result = align_labels(word_ids, word_labels) + self.assertEqual(result, [IGNORE_INDEX, 0, 11, 12, IGNORE_INDEX]) + + def test_subword_split(self): + """Non-first subtokens should get IGNORE_INDEX.""" + # "Smith" split into 2 subtokens (word_id=1 twice) + word_ids = [None, 0, 1, 1, 2, None] + word_labels = [0, 12, 0] # O, I-NAME, O + result = align_labels(word_ids, word_labels) + self.assertEqual( + result, + [IGNORE_INDEX, 0, 12, 
IGNORE_INDEX, 0, IGNORE_INDEX], + ) + + def test_all_special_tokens(self): + """All-None word_ids should produce all IGNORE_INDEX.""" + word_ids = [None, None] + result = align_labels(word_ids, []) + self.assertEqual(result, [IGNORE_INDEX, IGNORE_INDEX]) + + +class TestTransformerDeIDInit(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = TransformerDeID(dataset=cls.dataset) + + @classmethod + def tearDownClass(cls): + cls.dataset.close() + + def test_feature_key(self): + self.assertEqual(self.model.feature_key, "text") + + def test_label_key(self): + self.assertEqual(self.model.label_key, "labels") + + def test_num_labels(self): + self.assertEqual(self.model.num_labels, 15) + + def test_classifier_output_dim(self): + self.assertEqual(self.model.classifier.out_features, 15) + + def test_encoder_hidden_size(self): + """BERT-base has hidden_size=768.""" + self.assertEqual(self.model.encoder.config.hidden_size, 768) + + +class TestTransformerDeIDForward(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = TransformerDeID(dataset=cls.dataset) + cls.model.eval() + # Run a forward pass with raw strings (same format as task output). 
+ with torch.no_grad(): + cls.result = cls.model( + text=[ + "Patient John Smith was seen", + "Admitted on 01/15/2024 to clinic", + ], + labels=[ + "O B-NAME I-NAME O O", + "O O B-DATE O O", + ], + ) + + @classmethod + def tearDownClass(cls): + cls.dataset.close() + + def test_output_has_required_keys(self): + for key in ("loss", "logit", "y_prob", "y_true"): + self.assertIn(key, self.result) + + def test_loss_is_scalar(self): + self.assertEqual(self.result["loss"].dim(), 0) + + def test_logit_shape(self): + """logit should be (batch, seq_len, num_labels).""" + logit = self.result["logit"] + self.assertEqual(logit.shape[0], 2) # batch size + self.assertEqual(logit.shape[2], 15) # num labels + + def test_y_prob_shape_matches_logit(self): + self.assertEqual( + self.result["y_prob"].shape, self.result["logit"].shape + ) + + def test_y_prob_sums_to_one(self): + """Softmax probabilities should sum to ~1 along label dim.""" + sums = self.result["y_prob"].sum(dim=-1) + self.assertTrue(torch.allclose(sums, torch.ones_like(sums), atol=1e-5)) + + def test_backward(self): + """Loss backward should produce gradients.""" + # Need train mode and fresh forward pass for gradients. 
+ self.model.train() + result = self.model( + text=["Patient John Smith was seen"], + labels=["O B-NAME I-NAME O O"], + ) + result["loss"].backward() + has_grad = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_grad) + self.model.eval() + self.model.zero_grad() + + def test_deidentify_returns_string(self): + result = self.model.deidentify("Patient John Smith was seen") + self.assertIsInstance(result, str) + + def test_deidentify_same_word_count(self): + """Output should have same number of words (redacted or not).""" + text = "Patient John Smith was seen" + result = self.model.deidentify(text) + self.assertEqual(len(result.split()), len(text.split())) + + def test_deidentify_custom_redact_marker(self): + result = self.model.deidentify("Patient John", redact="[PHI]") + self.assertNotIn("[REDACTED]", result) + + +if __name__ == "__main__": + unittest.main()