diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b02439d26..8d9a59d21 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -238,6 +238,7 @@ Available Datasets
datasets/pyhealth.datasets.BMDHSDataset
datasets/pyhealth.datasets.COVID19CXRDataset
datasets/pyhealth.datasets.ChestXray14Dataset
+ datasets/pyhealth.datasets.PhysioNetDeIDDataset
datasets/pyhealth.datasets.TUABDataset
datasets/pyhealth.datasets.TUEVDataset
datasets/pyhealth.datasets.ClinVarDataset
diff --git a/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst
new file mode 100644
index 000000000..4e04cd629
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst
@@ -0,0 +1,9 @@
+pyhealth.datasets.PhysioNetDeIDDataset
+=======================================
+
+The PhysioNet De-Identification dataset. For more information see `here <https://physionet.org/content/deidentifiedmedicaltext/1.0/>`_. Access requires PhysioNet credentialing.
+
+.. autoclass:: pyhealth.datasets.PhysioNetDeIDDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..166402e86 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -177,6 +177,7 @@ API Reference
models/pyhealth.models.GNN
models/pyhealth.models.Transformer
models/pyhealth.models.TransformersModel
+ models/pyhealth.models.TransformerDeID
models/pyhealth.models.RETAIN
models/pyhealth.models.GAMENet
models/pyhealth.models.GraphCare
diff --git a/docs/api/models/pyhealth.models.TransformerDeID.rst b/docs/api/models/pyhealth.models.TransformerDeID.rst
new file mode 100644
index 000000000..d07aa94aa
--- /dev/null
+++ b/docs/api/models/pyhealth.models.TransformerDeID.rst
@@ -0,0 +1,9 @@
+pyhealth.models.TransformerDeID
+===================================
+
+Transformer-based token classifier for clinical text de-identification.
+
+.. autoclass:: pyhealth.models.TransformerDeID
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 399b8f1aa..23a4e06e5 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -224,6 +224,7 @@ Available Tasks
Sleep Staging v2
Benchmark EHRShot
ChestX-ray14 Binary Classification
+ De-Identification NER
ChestX-ray14 Multilabel Classification
Variant Classification (ClinVar)
Mutation Pathogenicity (COSMIC)
diff --git a/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst
new file mode 100644
index 000000000..2b7428f6e
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst
@@ -0,0 +1,7 @@
+pyhealth.tasks.DeIDNERTask
+=======================================
+
+.. autoclass:: pyhealth.tasks.DeIDNERTask
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py
new file mode 100644
index 000000000..eaa421bd9
--- /dev/null
+++ b/examples/physionet_deid_ner_transformer_deid.py
@@ -0,0 +1,191 @@
+"""Train and evaluate TransformerDeID on PhysioNet de-identification.
+
+End-to-end example: load data, train BERT-base on token-level NER for
+PHI detection, and report binary (PHI vs non-PHI) precision/recall/F1.
+
+Paper: Johnson et al. "Deidentification of free-text medical records
+ using pre-trained bidirectional transformers." CHIL, 2020.
+
+Script structure follows examples/cardiology_detection_isAR_SparcNet.py.
+
+Hyperparameters follow the paper (Section 3.4):
+ - Learning rate: 5e-5
+ - Batch size: 8
+ - Epochs: 3
+ - Weight decay: 0.01
+
+Ablation results (3 epochs, 80/10/10 patient split, seed=42):
+
+ Config Precision Recall F1
+ BERT, no window 95.1% 70.3% 80.8%
+ BERT, win=64/32 94.1% 69.0% 79.6%
+ BERT, win=100/60 86.9% 75.7% 80.9%
+ BERT, win=200/100 94.7% 69.4% 80.1%
+ RoBERTa, no window 98.1% 64.7% 78.0%
+ RoBERTa, win=100/60 82.6% 68.6% 75.0%
+
+ BERT with window=100/60 achieves the best F1 (80.9%), matching the
+ paper's window configuration. Windowing improves recall by allowing
+ BERT to see tokens beyond the 512 truncation limit. RoBERTa has
+ higher precision but lower recall than BERT across configurations.
+
+Usage:
+ python examples/physionet_deid_ner_transformer_deid.py \
+ --data_root path/to/deidentifiedmedicaltext/1.0
+
+ # With windowing (paper Section 3.3):
+ python examples/physionet_deid_ner_transformer_deid.py \
+ --data_root path/to/data --window_size 100 --window_overlap 60
+
+ # With RoBERTa:
+ python examples/physionet_deid_ner_transformer_deid.py \
+ --data_root path/to/data --model_name roberta-base
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+
+import argparse
+from collections import defaultdict
+
+import numpy as np
+import torch
+from sklearn.metrics import precision_score, recall_score, f1_score
+
+from pyhealth.datasets import PhysioNetDeIDDataset, get_dataloader
+from pyhealth.datasets.splitter import split_by_patient
+from pyhealth.models.transformer_deid import (
+ IGNORE_INDEX,
+ TransformerDeID,
+)
+from pyhealth.tasks import DeIDNERTask
+from pyhealth.trainer import Trainer
+
+
+def compute_metrics(model, dataloader):
+ """Binary PHI vs non-PHI token-level metrics with window merging.
+
+ When windowing is used, multiple windows may cover the same token.
+ We merge by taking the non-O prediction with highest probability
+ (paper Section 3.3). Without windowing, each token appears once
+ so no merging is needed.
+ """
+ # Collect per-token gold labels and prediction probabilities,
+ # keyed by (patient_id, note_id, absolute_token_position).
+ token_gold = {}
+ token_preds = defaultdict(list)
+
+ model.eval()
+ with torch.no_grad():
+ for batch in dataloader:
+ result = model(**batch)
+ probs = result["y_prob"] # (batch, seq_len, num_labels)
+ labels = result["y_true"] # (batch, seq_len)
+ patient_ids = batch["patient_id"]
+ note_ids = batch["note_id"]
+ token_starts = batch["token_start"]
+
+ for i in range(len(patient_ids)):
+ pid = patient_ids[i]
+ nid = note_ids[i]
+ start = int(token_starts[i])
+ word_idx = 0
+ for j in range(labels.shape[1]):
+ if labels[i, j].item() == IGNORE_INDEX:
+ continue
+ key = (pid, nid, start + word_idx)
+ token_gold[key] = labels[i, j].item()
+ token_preds[key].append(probs[i, j].cpu().numpy())
+ word_idx += 1
+
+ # Merge overlapping predictions (paper Section 3.3):
+ # if any window predicts non-O, take the non-O with highest score.
+ all_true, all_pred = [], []
+ for key in sorted(token_gold):
+ all_true.append(token_gold[key])
+ preds = token_preds[key]
+ non_o = [(p, p[1:].max()) for p in preds if np.argmax(p) != 0]
+ if non_o:
+ merged = max(non_o, key=lambda x: x[1])[0]
+ else:
+ merged = np.mean(preds, axis=0)
+ all_pred.append(int(np.argmax(merged)))
+
+ # Convert to binary: O (index 0) = 0, any PHI = 1.
+ true_bin = [0 if t == 0 else 1 for t in all_true]
+ pred_bin = [0 if p == 0 else 1 for p in all_pred]
+ return {
+ "precision": precision_score(true_bin, pred_bin),
+ "recall": recall_score(true_bin, pred_bin),
+ "f1": f1_score(true_bin, pred_bin),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_root",
+ type=str,
+ required=True,
+ help="Path to deidentifiedmedicaltext/1.0 directory",
+ )
+ parser.add_argument("--model_name", type=str, default="bert-base-uncased")
+ parser.add_argument("--epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=8)
+ parser.add_argument("--lr", type=float, default=5e-5)
+ parser.add_argument("--window_size", type=int, default=None,
+ help="Token window size (default: no windowing)")
+ parser.add_argument("--window_overlap", type=int, default=0)
+ parser.add_argument("--seed", type=int, default=42)
+ args = parser.parse_args()
+
+ # 1. Load dataset and set task.
+ print("Loading dataset...")
+ dataset = PhysioNetDeIDDataset(root=args.data_root)
+ task = DeIDNERTask(
+ window_size=args.window_size,
+ window_overlap=args.window_overlap,
+ )
+ samples = dataset.set_task(task)
+ print(f" Patients: {len(dataset.unique_patient_ids)}, Samples: {len(samples)}")
+
+ # 2. Split by patient (80/10/10) so no patient's notes appear in
+ # both train and test.
+ train_data, val_data, test_data = split_by_patient(
+ samples, [0.8, 0.1, 0.1], seed=args.seed
+ )
+ train_loader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
+ val_loader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)
+ test_loader = get_dataloader(test_data, batch_size=args.batch_size, shuffle=False)
+
+ # 3. Create model.
+ model = TransformerDeID(
+ dataset=samples,
+ model_name=args.model_name,
+ )
+
+ # 4. Train using PyHealth's Trainer.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ trainer = Trainer(model=model, device=device)
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ epochs=args.epochs,
+ optimizer_class=torch.optim.AdamW,
+ optimizer_params={"lr": args.lr},
+ weight_decay=0.01,
+ monitor="loss",
+ monitor_criterion="min",
+ )
+
+ # 5. Evaluate on test set.
+ print("\n=== Test Set Results (binary PHI vs non-PHI) ===")
+ metrics = compute_metrics(model, test_loader)
+ for k, v in metrics.items():
+ print(f" {k}: {v:.4f}")
+
+ samples.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 54e77670c..50b1b3887 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs):
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
+from .physionet_deid import PhysioNetDeIDDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
diff --git a/pyhealth/datasets/configs/physionet_deid.yaml b/pyhealth/datasets/configs/physionet_deid.yaml
new file mode 100644
index 000000000..2054ab809
--- /dev/null
+++ b/pyhealth/datasets/configs/physionet_deid.yaml
@@ -0,0 +1,10 @@
+version: "1.0"
+tables:
+ physionet_deid:
+ file_path: "physionet_deid_metadata.csv"
+ patient_id: "patient_id"
+ timestamp: null
+ attributes:
+ - "note_id"
+ - "text"
+ - "labels"
diff --git a/pyhealth/datasets/physionet_deid.py b/pyhealth/datasets/physionet_deid.py
new file mode 100644
index 000000000..5bd6acc95
--- /dev/null
+++ b/pyhealth/datasets/physionet_deid.py
@@ -0,0 +1,352 @@
+"""
+PyHealth dataset for the PhysioNet De-Identification dataset.
+
+Dataset link:
+ https://physionet.org/content/deidentifiedmedicaltext/1.0/
+
+Dataset paper: (please cite if you use this dataset)
+ Neamatullah, Ishna, et al. "Automated de-identification of free-text
+ medical records." BMC Medical Informatics and Decision Making 8.1 (2008).
+
+Paper link:
+ https://doi.org/10.1186/1472-6947-8-32
+
+PHI category mapping in classify_phi() inspired by the label groupings
+in the bert-deid reference implementation by Johnson et al.:
+ https://github.com/alistairewj/bert-deid/blob/master/bert_deid/label.py
+
+Task paper:
+ Johnson, Alistair E.W., et al. "Deidentification of free-text medical
+ records using pre-trained bidirectional transformers." Proceedings of
+ the ACM Conference on Health, Inference, and Learning (CHIL), 2020.
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import pandas as pd
+
+from pyhealth.datasets import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+# -- Record parsing regexes --
+
+_RECORD_START = re.compile(r"START_OF_RECORD=(\d+)\|\|\|\|(\d+)\|\|\|\|")
+_RECORD_END = re.compile(r"\|\|\|\|END_OF_RECORD")
+_PHI_TAG = re.compile(r"\[\*\*(.+?)\*\*\]", re.DOTALL)
+_PHI_SPLIT = re.compile(r"\[\*\*(?:.+?)\*\*\]", re.DOTALL)
+
+
+def _parse_file(path: Path) -> Dict[Tuple[str, str], str]:
+ """Parse a PhysioNet record file into {(patient_id, note_id): body}.
+
+ Args:
+ path: Path to id.text or id.res file.
+
+ Returns:
+ Dictionary mapping (patient_id, note_id) to note body text.
+ """
+ raw = path.read_text(encoding="utf-8", errors="ignore")
+ out: Dict[Tuple[str, str], str] = {}
+ for m in _RECORD_START.finditer(raw):
+ pid, nid = m.group(1), m.group(2)
+ body_start = m.end()
+ end_m = _RECORD_END.search(raw, body_start)
+ body = raw[body_start : end_m.start() if end_m else len(raw)]
+ out[(pid, nid)] = body.strip()
+ return out
+
+
+def classify_phi(raw: str) -> str:
+ """Map raw [**...**] tag text to one of the 7 PHI categories.
+
+ Args:
+ raw: The text inside a [**...**] tag.
+
+ Returns:
+ One of: AGE, DATE, CONTACT, LOCATION, ID, PROFESSION, NAME.
+ """
+ t = re.sub(r"[^a-z0-9 ]+", " ", raw.strip().lower()).strip()
+
+ if any(k in t for k in ("year old", " yo ", " age ")):
+ return "AGE"
+ if any(k in t for k in ("date", "month", "day", "year", "holiday")):
+ return "DATE"
+ if re.fullmatch(r"[\d]{1,2}[ \-/][\d]{1,2}([ \-/][\d]{2,4})?", t):
+ return "DATE"
+ if re.fullmatch(r"[\d]+", raw.strip()):
+ return "DATE"
+ if any(k in t for k in ("phone", "fax", "email", "pager", "contact")):
+ return "CONTACT"
+ if any(
+ k in t
+ for k in (
+ "hospital",
+ "location",
+ "street",
+ "county",
+ "state",
+ "country",
+ "zip",
+ "address",
+ "ward",
+ "room",
+ )
+ ):
+ return "LOCATION"
+ if any(
+ k in t
+ for k in (
+ "mrn",
+ "medical record",
+ "record number",
+ "ssn",
+ "account",
+ "serial",
+ "unit no",
+ "unit number",
+ "identifier",
+ )
+ ):
+ return "ID"
+ if " id " in f" {t} ":
+ return "ID"
+ if any(
+ k in t
+ for k in (
+ "doctor",
+ " dr ",
+ " md ",
+ "nurse",
+ "attending",
+ "resident",
+ "profession",
+ "service",
+ "provider",
+ )
+ ):
+ return "PROFESSION"
+ if any(
+ k in t
+ for k in (
+ "name",
+ "initial",
+ "alias",
+ "patient",
+ "first name",
+ "last name",
+ )
+ ):
+ return "NAME"
+ return "NAME" # fallback
+
+
+def phi_spans_in_original(
+ orig: str, deid: str
+) -> List[Tuple[int, int, str]]:
+ """Find PHI character spans in orig by anchoring on non-PHI chunks.
+
+ Uses non-PHI text from the de-identified version as anchors to locate
+ where the original PHI text appears in the original note.
+
+ Args:
+ orig: Original note text (with real PHI).
+ deid: De-identified note text (PHI replaced with [**...**] tags).
+
+ Returns:
+ List of (char_start, char_end, phi_category) tuples.
+ """
+ parts = _PHI_SPLIT.split(deid)
+ tags = _PHI_TAG.findall(deid)
+
+ spans: List[Tuple[int, int, str]] = []
+ pos = 0
+
+ for i, tag_inner in enumerate(tags):
+ before = parts[i]
+ if before:
+ idx = orig.find(before, pos)
+ pos = (idx + len(before)) if idx != -1 else (pos + len(before))
+
+ phi_start = pos
+
+ after = parts[i + 1]
+ if after:
+ idx = orig.find(after, phi_start)
+ phi_end = idx if idx != -1 else phi_start
+ else:
+ phi_end = len(orig)
+
+ if phi_end > phi_start:
+ spans.append((phi_start, phi_end, classify_phi(tag_inner)))
+
+ pos = phi_end
+
+ return spans
+
+
+def bio_tag(
+ text: str, spans: List[Tuple[int, int, str]]
+) -> List[Tuple[str, str]]:
+ """Whitespace-tokenize text and assign BIO labels from char-level spans.
+
+ Args:
+ text: Original note text.
+ spans: List of (char_start, char_end, phi_category) tuples.
+
+ Returns:
+ List of (word, label) tuples.
+ """
+ char_label = ["O"] * len(text)
+ for start, end, cat in spans:
+ for i in range(start, min(end, len(text))):
+ char_label[i] = cat
+
+ result: List[Tuple[str, str]] = []
+ for m in re.finditer(r"\S+", text):
+ w_start, w_end = m.start(), m.end()
+ word = m.group()
+ # Collect PHI categories for this token's characters, ignoring O's.
+ # If empty, the token has no PHI and we label it O.
+ cats = [c for c in char_label[w_start:w_end] if c != "O"]
+ if not cats:
+ result.append((word, "O"))
+ continue
+ # Pick the most common category. Handles rare cases where a token
+ # spans two PHI types, e.g. "Smith01/15" -> chars are NAME+DATE,
+ # majority wins (DATE).
+ cat = max(set(cats), key=cats.count)
+ # B = beginning of a new entity, I = continuation of the same one.
+ # Use "I" only if the previous token was the same category,
+ # e.g. "Tom"=B-NAME "Garcia"=I-NAME. Otherwise start a new "B".
+ prev_label = result[-1][1] if result else "O"
+ prefix = (
+ "I"
+ if prev_label not in ("O",) and prev_label.endswith(cat)
+ else "B"
+ )
+ result.append((word, f"{prefix}-{cat}"))
+
+ return result
+
+
+class PhysioNetDeIDDataset(BaseDataset):
+ """Dataset class for the PhysioNet De-Identification dataset.
+
+ This dataset contains 2,434 nursing notes from 163 patients.
+ Each note has original text with PHI (protected health information)
+ and a de-identified version with [**...**] tags marking PHI spans.
+
+ The dataset parses both files to produce token-level BIO labels
+ for 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME,
+ PROFESSION.
+
+ Data access requires PhysioNet credentialing:
+ 1. Create a PhysioNet account at https://physionet.org
+ 2. Complete the required CITI training
+ 3. Sign the data use agreement
+ 4. Download from
+ https://physionet.org/content/deidentifiedmedicaltext/1.0/
+
+ Attributes:
+ root (str): Root directory containing id.text and id.res files.
+ dataset_name (str): Name of the dataset.
+
+ Example::
+ >>> dataset = PhysioNetDeIDDataset(root="./data/physionet_deid")
+ """
+
+ def __init__(
+ self,
+ root: str = ".",
+ config_path: Optional[str] = str(
+ Path(__file__).parent / "configs" / "physionet_deid.yaml"
+ ),
+ **kwargs,
+ ) -> None:
+ """Initializes the PhysioNet De-Identification dataset.
+
+ Args:
+ root: Root directory containing id.text and id.res files.
+ config_path: Path to the configuration file.
+
+ Raises:
+ FileNotFoundError: If id.text or id.res not found in root.
+
+ Example::
+ >>> dataset = PhysioNetDeIDDataset(root="./data")
+ """
+ self._verify_data(root)
+ self._index_data(root)
+
+ super().__init__(
+ root=root,
+ tables=["physionet_deid"],
+ dataset_name="PhysioNetDeID",
+ config_path=config_path,
+ **kwargs,
+ )
+
+ def _verify_data(self, root: str) -> None:
+ """Verify that required data files exist.
+
+ Args:
+ root: Root directory to check.
+
+ Raises:
+ FileNotFoundError: If id.text or id.res is missing.
+ """
+ for fname in ("id.text", "id.res"):
+ path = os.path.join(root, fname)
+ if not os.path.isfile(path):
+ raise FileNotFoundError(
+ f"Required file '{fname}' not found in {root}"
+ )
+
+ def _index_data(self, root: str) -> pd.DataFrame:
+ """Parse id.text and id.res into a CSV for BaseDataset to load.
+
+ Args:
+ root: Root directory containing the data files.
+
+ Returns:
+ DataFrame with columns: patient_id, note_id, text, labels.
+ """
+ root_path = Path(root)
+ orig_records = _parse_file(root_path / "id.text")
+ deid_records = _parse_file(root_path / "id.res")
+
+ rows = []
+ for key in sorted(
+ orig_records, key=lambda k: (int(k[0]), int(k[1]))
+ ):
+ pid, nid = key
+ orig = orig_records[key]
+ # Missing key yields empty string (no deid version).
+ deid = deid_records.get(key, "")
+ spans = phi_spans_in_original(orig, deid)
+ tagged = bio_tag(orig, spans)
+
+ tokens = " ".join(w for w, _ in tagged)
+ labels = " ".join(lbl for _, lbl in tagged)
+
+ rows.append(
+ {
+ "patient_id": pid,
+ "note_id": nid,
+ "text": tokens,
+ "labels": labels,
+ }
+ )
+
+ df = pd.DataFrame(rows)
+ df.to_csv(
+ os.path.join(root, "physionet_deid_metadata.csv"), index=False
+ )
+ return df
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 5233b1726..deabaf95c 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -1,6 +1,7 @@
from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare
from .agent import Agent, AgentLayer
from .base_model import BaseModel
+from .transformer_deid import TransformerDeID
from .biot import BIOT
from .cnn import CNN, CNNLayer
from .concare import ConCare, ConCareLayer
diff --git a/pyhealth/models/transformer_deid.py b/pyhealth/models/transformer_deid.py
new file mode 100644
index 000000000..964942f5a
--- /dev/null
+++ b/pyhealth/models/transformer_deid.py
@@ -0,0 +1,263 @@
+"""
+PyHealth model for transformer-based clinical text de-identification.
+
+Performs token-level NER to detect PHI (protected health information)
+in clinical notes using a pre-trained transformer with a classification
+head.
+
+Paper: Johnson, Alistair E.W., et al. "Deidentification of free-text
+ medical records using pre-trained bidirectional transformers."
+ Proceedings of the ACM Conference on Health, Inference, and
+ Learning (CHIL), 2020.
+
+Paper link:
+ https://doi.org/10.1145/3368555.3384455
+
+Model structure (dropout + linear head) follows PyHealth's
+TransformersModel (pyhealth/models/transformers_model.py), adapted
+for token-level classification instead of sequence-level.
+
+Subword alignment follows the standard HuggingFace token
+classification pattern (see BertForTokenClassification).
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+
+import logging
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+
+from ..datasets import SampleDataset
+from .base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+# 7 PHI categories with BIO prefix, plus O for non-PHI.
+LABEL_VOCAB = {
+ "O": 0,
+ "B-AGE": 1, "I-AGE": 2,
+ "B-CONTACT": 3, "I-CONTACT": 4,
+ "B-DATE": 5, "I-DATE": 6,
+ "B-ID": 7, "I-ID": 8,
+ "B-LOCATION": 9, "I-LOCATION": 10,
+ "B-NAME": 11, "I-NAME": 12,
+ "B-PROFESSION": 13, "I-PROFESSION": 14,
+}
+
+# Cross-entropy ignores positions with this index (PyTorch convention).
+IGNORE_INDEX = -100
+
+
+def align_labels(
+ word_ids: List[int | None],
+ word_labels: List[int],
+) -> List[int]:
+ """Align word-level labels to subword tokens.
+
+ BERT/RoBERTa tokenizers split words into subwords. For example,
+ "Smith" might become ["Sm", "##ith"]. This function assigns the
+ word's label to the first subtoken and IGNORE_INDEX to the rest,
+ so the loss function skips non-first subtokens. Special tokens
+ ([CLS], [SEP], padding) have word_id=None and also get
+ IGNORE_INDEX.
+
+ Args:
+ word_ids: Output of tokenizer.word_ids(). None for special
+ tokens, integer word index for real tokens.
+ word_labels: Label index for each word in the original text.
+
+ Returns:
+ List of label indices, one per subtoken. Non-first subtokens
+ and special tokens are set to IGNORE_INDEX (-100).
+ """
+ aligned = []
+ prev_word_id = None
+ for word_id in word_ids:
+ if word_id is None:
+ # Special token ([CLS], [SEP], padding).
+ aligned.append(IGNORE_INDEX)
+ elif word_id != prev_word_id:
+ # First subtoken of a word: use the word's label.
+ aligned.append(word_labels[word_id])
+ else:
+ # Non-first subtoken: ignore during loss computation.
+ aligned.append(IGNORE_INDEX)
+ prev_word_id = word_id
+ return aligned
+
+
+class TransformerDeID(BaseModel):
+ """Transformer-based token classifier for clinical text de-identification.
+
+ Uses a pre-trained transformer encoder with a linear classification
+ head to predict BIO-tagged PHI labels for each token.
+
+ Args:
+ dataset: A SampleDataset from set_task().
+ model_name: HuggingFace model name. Default "bert-base-uncased".
+ max_length: Maximum token sequence length. Default 512.
+ dropout: Dropout rate for the classification head. Default 0.1.
+
+ Examples:
+ >>> from pyhealth.datasets import PhysioNetDeIDDataset
+ >>> from pyhealth.tasks import DeIDNERTask
+ >>> from pyhealth.models import TransformerDeID
+ >>> dataset = PhysioNetDeIDDataset(root="/path/to/data")
+ >>> samples = dataset.set_task(DeIDNERTask())
+ >>> model = TransformerDeID(dataset=samples) # BERT
+ >>> model = TransformerDeID(dataset=samples, model_name="roberta-base")
+ """
+
+ def __init__(
+ self,
+ dataset: SampleDataset,
+ model_name: str = "bert-base-uncased",
+ max_length: int = 512,
+ dropout: float = 0.1,
+ ):
+ super(TransformerDeID, self).__init__(dataset=dataset)
+
+ assert len(self.feature_keys) == 1, (
+ "TransformerDeID expects exactly one input feature (text)."
+ )
+ assert len(self.label_keys) == 1, (
+ "TransformerDeID expects exactly one label key."
+ )
+ self.feature_key = self.feature_keys[0]
+ self.label_key = self.label_keys[0]
+
+ self.model_name = model_name
+ self.max_length = max_length
+ self.label_vocab = LABEL_VOCAB
+ self.num_labels = len(LABEL_VOCAB)
+
+ # add_prefix_space=True is required for RoBERTa when using
+ # is_split_into_words=True in the forward pass.
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_name, add_prefix_space=True
+ )
+ self.encoder = AutoModel.from_pretrained(
+ model_name,
+ hidden_dropout_prob=dropout,
+ attention_probs_dropout_prob=dropout,
+ )
+ hidden_size = self.encoder.config.hidden_size
+ self.dropout = nn.Dropout(dropout)
+ self.classifier = nn.Linear(hidden_size, self.num_labels)
+
+ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ **kwargs: Must contain self.feature_key (list of
+ space-joined token strings) and self.label_key
+ (list of space-joined BIO label strings).
+
+ Returns:
+ Dict with keys: loss, logit, y_prob, y_true.
+ """
+ texts: List[str] = kwargs[self.feature_key]
+ label_strings: List[str] = kwargs[self.label_key]
+
+ # Tokenize with is_split_into_words=True so the tokenizer
+ # knows word boundaries and word_ids() works correctly.
+ words_batch = [t.split(" ") for t in texts]
+ encoding = self.tokenizer(
+ words_batch,
+ is_split_into_words=True,
+ padding=True,
+ truncation=True,
+ max_length=self.max_length,
+ return_tensors="pt",
+ )
+
+ # Convert word-level label strings to indices, then align
+ # to subword tokens. Positions that should be ignored during
+ # loss (special tokens, non-first subtokens, padding) get
+ # IGNORE_INDEX (-100), which cross-entropy skips.
+ aligned_labels = []
+ for i, label_str in enumerate(label_strings):
+ word_labels = [
+ self.label_vocab[lbl] for lbl in label_str.split(" ")
+ ]
+ word_ids = encoding.word_ids(batch_index=i)
+ aligned_labels.append(align_labels(word_ids, word_labels))
+
+ labels = torch.tensor(aligned_labels, dtype=torch.long)
+
+ # Move to device
+ input_ids = encoding["input_ids"].to(self.device)
+ attention_mask = encoding["attention_mask"].to(self.device)
+ labels = labels.to(self.device)
+
+ # Encoder -> dropout -> classifier (per-token logits)
+ hidden_states = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ ).last_hidden_state
+ logits = self.classifier(self.dropout(hidden_states))
+
+ # Token-level cross-entropy, ignoring padded/special positions.
+ # We can't use BaseModel.get_loss_function() because it assumes
+ # one label per sample. Instead we call cross_entropy directly
+ # with ignore_index to skip special tokens and non-first subtokens.
+ # Flatten + ignore_index pattern from HuggingFace's
+ # BertForTokenClassification.forward().
+ loss = nn.functional.cross_entropy(
+ logits.view(-1, self.num_labels),
+ labels.view(-1),
+ ignore_index=IGNORE_INDEX,
+ )
+
+ # Per-token probabilities via softmax.
+ y_prob = torch.softmax(logits, dim=-1)
+
+ return {
+ "loss": loss,
+ "logit": logits,
+ "y_prob": y_prob,
+ "y_true": labels,
+ }
+
+ def deidentify(self, text: str, redact: str = "[REDACTED]") -> str:
+ """Replace PHI in a clinical note with a redaction marker.
+
+ Args:
+ text: Raw clinical note as a string.
+ redact: Replacement string for PHI tokens.
+
+ Returns:
+ The note with PHI tokens replaced.
+
+ Example::
+ >>> model.deidentify("Patient John Smith was seen")
+ 'Patient [REDACTED] [REDACTED] was seen'
+ """
+ words = text.split()
+ # Forward pass with dummy labels (all O) since we only
+ # need predictions, not loss.
+ dummy_labels = " ".join(["O"] * len(words))
+ self.eval()
+ with torch.no_grad():
+ result = self(text=[text], labels=[dummy_labels])
+
+ preds = result["logit"][0].argmax(dim=-1)
+ y_true = result["y_true"][0]
+
+ # Map predictions back to words using the non-ignored positions.
+ word_idx = 0
+ output = []
+ for j in range(len(preds)):
+ if y_true[j].item() == IGNORE_INDEX:
+ continue
+ if preds[j].item() != 0: # non-O = PHI
+ output.append(redact)
+ else:
+ output.append(words[word_idx])
+ word_idx += 1
+
+ return " ".join(output)
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 797988377..a32618f9c 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -12,6 +12,7 @@
from .chestxray14_binary_classification import ChestXray14BinaryClassification
from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification
from .covid19_cxr_classification import COVID19CXRClassification
+from .deid_ner import DeIDNERTask
from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4
from .drug_recommendation import (
DrugRecommendationEICU,
diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py
new file mode 100644
index 000000000..67b212bac
--- /dev/null
+++ b/pyhealth/tasks/deid_ner.py
@@ -0,0 +1,118 @@
+"""
+PyHealth task for NER-based de-identification of clinical text.
+
+Converts PhysioNet De-Identification dataset records into token-level
+BIO-tagged NER samples for PHI detection.
+
+Dataset link:
+ https://physionet.org/content/deidentifiedmedicaltext/1.0/
+
+Task paper: (please cite if you use this task)
+ Johnson, Alistair E.W., et al. "Deidentification of free-text medical
+ records using pre-trained bidirectional transformers." Proceedings of
+ the ACM Conference on Health, Inference, and Learning (CHIL), 2020.
+
+Paper link:
+ https://doi.org/10.1145/3368555.3384455
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+
+from typing import Dict, List, Optional, Type, Union
+
+from pyhealth.data import Event, Patient
+from pyhealth.processors.text_processor import TextProcessor
+from pyhealth.tasks import BaseTask
+
+
+class DeIDNERTask(BaseTask):
+ """Token-level NER task for clinical text de-identification.
+
+ Each sample contains a list of tokens and their BIO labels over
+ 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME,
+ PROFESSION.
+
+ Supports optional overlapping windowing (paper Section 3.3) to
+ handle notes longer than BERT's 512 token limit.
+
+ Args:
+ window_size: If set, split notes into overlapping windows of
+ this many tokens. Default None (no windowing).
+ window_overlap: Number of tokens shared between consecutive
+ windows. Default 0.
+
+ Attributes:
+ task_name (str): The name of the task.
+ input_schema (Dict[str, str]): The schema for the task input.
+ output_schema (Dict[str, str]): The schema for the task output.
+
+ Examples:
+ >>> from pyhealth.datasets import PhysioNetDeIDDataset
+ >>> from pyhealth.tasks import DeIDNERTask
+ >>> dataset = PhysioNetDeIDDataset(root="/path/to/data")
+ >>> task = DeIDNERTask()
+ >>> samples = dataset.set_task(task)
+ >>> task_windowed = DeIDNERTask(window_size=100, window_overlap=60)
+ >>> samples = dataset.set_task(task_windowed)
+ """
+
+ task_name: str = "DeIDNER"
+ input_schema: Dict[str, str] = {"text": "text"}
+ # Labels are kept as a space-joined string. The model splits and
+ # encodes them into label indices.
+ output_schema: Dict[str, Union[str, Type]] = {"labels": TextProcessor}
+
+ def __init__(
+ self,
+ window_size: Optional[int] = None,
+ window_overlap: int = 0,
+ ):
+ self.window_size = window_size
+ self.window_overlap = window_overlap
+
+ def __call__(self, patient: Patient) -> List[Dict]:
+ """Generate NER samples from a patient's clinical notes.
+
+ Args:
+ patient: A Patient object with physionet_deid events.
+
+ Returns:
+ List of dicts, each with 'text' (str) and
+ 'labels' (str) keys. Both are space-joined strings.
+ """
+ events: List[Event] = patient.get_events(
+ event_type="physionet_deid"
+ )
+
+ samples = []
+ for event in events:
+ note_id = event["note_id"]
+ words = event["text"].split(" ")
+ labels = event["labels"].split(" ")
+
+ if self.window_size is None:
+ # No windowing: one sample per note.
+ samples.append({
+ "patient_id": patient.patient_id,
+ "note_id": note_id,
+ "token_start": "0",
+ "text": event["text"],
+ "labels": event["labels"],
+ })
+ else:
+ # Overlapping windows (paper Section 3.3).
+ step = self.window_size - self.window_overlap
+ idx = 0
+ while idx < len(words):
+ end = min(idx + self.window_size, len(words))
+ samples.append({
+ "patient_id": patient.patient_id,
+ "note_id": note_id,
+ "token_start": str(idx),
+ "text": " ".join(words[idx:end]),
+ "labels": " ".join(labels[idx:end]),
+ })
+ idx += step
+
+ return samples
diff --git a/test-resources/core/physionet_deid/id.res b/test-resources/core/physionet_deid/id.res
new file mode 100644
index 000000000..1081b2717
--- /dev/null
+++ b/test-resources/core/physionet_deid/id.res
@@ -0,0 +1,12 @@
+START_OF_RECORD=10||||1||||
+Patient [**First Name 101**] [**Last Name 102**] was admitted on [**Date 103**] to [**Hospital 104**]. She is [**Age 105**] years old. Contact phone [**Phone 106**]. MRN [**Medical Record Number 107**].
+||||END_OF_RECORD
+START_OF_RECORD=10||||2||||
+Seen by [**Doctor First Name 201**] [**Doctor Last Name 202**] in radiology. No acute findings.
+||||END_OF_RECORD
+START_OF_RECORD=20||||1||||
+Assessment unchanged. Vitals stable overnight. Continue current plan.
+||||END_OF_RECORD
+START_OF_RECORD=60||||1||||
+[**First Name 301**] [**Last Name 302**] presented to [**Hospital 303**] on [**Date 304**] for follow-up. He works as a [**Profession 305**]. Patient ID [**Medical Record Number 306**].
+||||END_OF_RECORD
diff --git a/test-resources/core/physionet_deid/id.text b/test-resources/core/physionet_deid/id.text
new file mode 100644
index 000000000..e61e9f09d
--- /dev/null
+++ b/test-resources/core/physionet_deid/id.text
@@ -0,0 +1,12 @@
+START_OF_RECORD=10||||1||||
+Patient Jane Doe was admitted on 03/12/2098 to Springfield General Hospital. She is 72 years old. Contact phone 555-867-5309. MRN 00112233.
+||||END_OF_RECORD
+START_OF_RECORD=10||||2||||
+Seen by Dr. Robert Wells in radiology. No acute findings.
+||||END_OF_RECORD
+START_OF_RECORD=20||||1||||
+Assessment unchanged. Vitals stable overnight. Continue current plan.
+||||END_OF_RECORD
+START_OF_RECORD=60||||1||||
+Tom Garcia presented to Lakewood Clinic on 11/05/2097 for follow-up. He works as a plumber. Patient ID 99887766.
+||||END_OF_RECORD
diff --git a/tests/core/test_physionet_deid.py b/tests/core/test_physionet_deid.py
new file mode 100644
index 000000000..6cfb47c6e
--- /dev/null
+++ b/tests/core/test_physionet_deid.py
@@ -0,0 +1,187 @@
+"""
+Unit tests for the PhysioNetDeIDDataset and DeIDNERTask classes.
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+from pathlib import Path
+import tempfile
+import unittest
+
+from pyhealth.datasets import PhysioNetDeIDDataset
+from pyhealth.tasks import DeIDNERTask
+
+
class TestPhysioNetDeIDDataset(unittest.TestCase):
    """End-to-end checks for PhysioNetDeIDDataset and the default DeIDNERTask."""

    @classmethod
    def setUpClass(cls):
        cls.root = (
            Path(__file__).parents[2]
            / "test-resources"
            / "core"
            / "physionet_deid"
        )
        cls.cache_dir = tempfile.TemporaryDirectory()
        cls.dataset = PhysioNetDeIDDataset(
            root=str(cls.root), cache_dir=cls.cache_dir.name
        )
        cls.task = DeIDNERTask()
        cls.samples = cls.dataset.set_task(cls.task)

    @classmethod
    def tearDownClass(cls):
        cls.samples.close()
        cls.cache_dir.cleanup()
        # The dataset writes a metadata CSV next to the raw files; remove
        # it so repeated runs start from a clean tree.
        (cls.root / "physionet_deid_metadata.csv").unlink(missing_ok=True)

    # -- Helpers --

    def _events(self, patient_id):
        """Return all note events recorded for one patient."""
        return self.dataset.get_patient(patient_id).get_events()

    def _tokens_and_labels(self, patient_id, note_index=0):
        """Split one note into parallel token and label lists."""
        note = self._events(patient_id)[note_index]
        return note["text"].split(" "), note["labels"].split(" ")

    # -- Dataset tests --

    def test_num_patients(self):
        self.assertEqual(3, len(self.dataset.unique_patient_ids))

    def test_patient_ids(self):
        self.assertSetEqual(
            set(self.dataset.unique_patient_ids), {"10", "20", "60"}
        )

    def test_patient_10_has_two_notes(self):
        self.assertEqual(2, len(self._events("10")))

    def test_patient_20_has_one_note(self):
        self.assertEqual(1, len(self._events("20")))

    def test_patient_60_has_one_note(self):
        self.assertEqual(1, len(self._events("60")))

    def test_patient_10_note1_has_tokens_and_labels(self):
        first_note = self._events("10")[0]
        for key in ("text", "labels"):
            self.assertIn(key, first_note)

    def test_patient_10_note1_token_count(self):
        """Note 1 for patient 10 should have the right number of tokens."""
        tokens, _ = self._tokens_and_labels("10")
        self.assertEqual(21, len(tokens))

    def test_patient_20_no_phi(self):
        """Patient 20's note has no PHI, so every label must be O."""
        _, labels = self._tokens_and_labels("20")
        self.assertEqual({"O"}, set(labels))

    def test_patient_60_has_name_labels(self):
        """Patient 60's note opens with a two-token NAME span."""
        _, labels = self._tokens_and_labels("60")
        self.assertEqual(["B-NAME", "I-NAME"], labels[:2])

    def test_patient_60_has_location_label(self):
        """Patient 60's note contains a two-token LOCATION span."""
        tokens, labels = self._tokens_and_labels("60")
        start = tokens.index("Lakewood")
        self.assertEqual(
            ["B-LOCATION", "I-LOCATION"], labels[start:start + 2]
        )

    def test_patient_60_has_date_label(self):
        """Patient 60's note contains a DATE token."""
        tokens, labels = self._tokens_and_labels("60")
        self.assertEqual("B-DATE", labels[tokens.index("11/05/2097")])

    def test_patient_60_has_profession_label(self):
        """Patient 60's note contains a PROFESSION token."""
        tokens, labels = self._tokens_and_labels("60")
        self.assertEqual("B-PROFESSION", labels[tokens.index("plumber.")])

    def test_stats(self):
        self.dataset.stats()

    # -- Task tests --

    def test_task_sample_count(self):
        """4 notes total across 3 patients."""
        self.assertEqual(4, len(self.samples))

    def test_task_sample_has_text_and_labels(self):
        first_sample = self.samples[0]
        for key in ("text", "labels"):
            self.assertIn(key, first_sample)

    def test_task_text_and_labels_same_length(self):
        for i, sample in enumerate(self.samples):
            with self.subTest(sample=i):
                self.assertEqual(
                    len(sample["text"].split(" ")),
                    len(sample["labels"].split(" ")),
                )

    def test_task_labels_are_valid_bio(self):
        categories = (
            "AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION"
        )
        valid = {"O"} | {
            f"{prefix}-{cat}" for prefix in ("B", "I") for cat in categories
        }
        for sample in self.samples:
            for label in sample["labels"].split(" "):
                self.assertIn(label, valid)

    def test_task_sample_has_patient_id(self):
        self.assertIn("patient_id", self.samples[0])
+
+
class TestDeIDNERTaskWindowing(unittest.TestCase):
    """Checks for the sliding-window variant of DeIDNERTask."""

    @classmethod
    def setUpClass(cls):
        cls.root = (
            Path(__file__).parents[2]
            / "test-resources"
            / "core"
            / "physionet_deid"
        )
        cls.cache_dir = tempfile.TemporaryDirectory()
        cls.dataset = PhysioNetDeIDDataset(
            root=str(cls.root), cache_dir=cls.cache_dir.name
        )
        cls.task = DeIDNERTask(window_size=10, window_overlap=5)
        cls.samples = cls.dataset.set_task(cls.task)

    @classmethod
    def tearDownClass(cls):
        cls.samples.close()
        cls.cache_dir.cleanup()
        # Remove the metadata CSV the dataset writes into the resource tree.
        (cls.root / "physionet_deid_metadata.csv").unlink(missing_ok=True)

    def test_windowing_produces_more_samples(self):
        """Windowing should split the 4 notes into additional samples."""
        self.assertGreater(len(self.samples), 4)

    def test_window_size_respected(self):
        """No window may exceed window_size tokens."""
        for i, sample in enumerate(self.samples):
            with self.subTest(sample=i):
                self.assertLessEqual(len(sample["text"].split(" ")), 10)

    def test_window_text_and_labels_same_length(self):
        """Token and label sequences stay aligned inside each window."""
        for i, sample in enumerate(self.samples):
            with self.subTest(sample=i):
                self.assertEqual(
                    len(sample["text"].split(" ")),
                    len(sample["labels"].split(" ")),
                )

    def test_window_has_patient_id(self):
        """Each window keeps a reference to its source patient."""
        self.assertIn("patient_id", self.samples[0])
+
+
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    unittest.main()
diff --git a/tests/core/test_transformer_deid.py b/tests/core/test_transformer_deid.py
new file mode 100644
index 000000000..e87a68c41
--- /dev/null
+++ b/tests/core/test_transformer_deid.py
@@ -0,0 +1,193 @@
+"""
+Unit tests for the TransformerDeID model.
+
+Author:
+ Matt McKenna (mtm16@illinois.edu)
+"""
+
+import unittest
+
+import torch
+
+from pyhealth.datasets import create_sample_dataset
+from pyhealth.models.transformer_deid import (
+ IGNORE_INDEX,
+ LABEL_VOCAB,
+ TransformerDeID,
+ align_labels,
+)
+from pyhealth.processors.text_processor import TextProcessor
+
+
def _make_dataset():
    """Build a tiny in-memory dataset shaped like DeIDNERTask output."""
    notes = (
        ("p1", "Patient John Smith was seen", "O B-NAME I-NAME O O"),
        ("p2", "Admitted on 01/15/2024 to clinic", "O O B-DATE O O"),
    )
    samples = [
        {"patient_id": pid, "text": text, "labels": labels}
        for pid, text, labels in notes
    ]
    return create_sample_dataset(
        samples=samples,
        input_schema={"text": TextProcessor},
        output_schema={"labels": TextProcessor},
        dataset_name="test_deid",
        task_name="DeIDNER",
        in_memory=True,
    )
+
+
class TestLabelVocab(unittest.TestCase):
    """Sanity checks on the BIO label vocabulary."""

    # The seven PHI categories the model distinguishes.
    CATEGORIES = ("AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION")

    def test_vocab_size(self):
        """O + 7 categories * 2 (B/I) = 15."""
        self.assertEqual(len(LABEL_VOCAB), 1 + 2 * len(self.CATEGORIES))

    def test_o_is_zero(self):
        """The outside tag must map to index 0."""
        self.assertEqual(LABEL_VOCAB["O"], 0)

    def test_all_categories_present(self):
        for prefix in ("B", "I"):
            for cat in self.CATEGORIES:
                self.assertIn(f"{prefix}-{cat}", LABEL_VOCAB)
+
+
class TestAlignLabels(unittest.TestCase):
    """Word-label to subtoken-label alignment behavior."""

    def test_no_subword_splits(self):
        """One subtoken per word: labels pass through, specials ignored."""
        # word_ids: None=CLS, then words 0..2, then None=SEP.
        aligned = align_labels([None, 0, 1, 2, None], [0, 11, 12])
        self.assertEqual(aligned, [IGNORE_INDEX, 0, 11, 12, IGNORE_INDEX])

    def test_subword_split(self):
        """Only the first subtoken of a split word keeps its label."""
        # Word 1 ("Smith") is split into two subtokens.
        aligned = align_labels([None, 0, 1, 1, 2, None], [0, 12, 0])
        expected = [IGNORE_INDEX, 0, 12, IGNORE_INDEX, 0, IGNORE_INDEX]
        self.assertEqual(aligned, expected)

    def test_all_special_tokens(self):
        """Input with only special tokens yields only IGNORE_INDEX."""
        self.assertEqual(align_labels([None, None], []), [IGNORE_INDEX] * 2)
+
+
class TestTransformerDeIDInit(unittest.TestCase):
    """Constructor wiring: keys, label count, and encoder dimensions."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()
        cls.model = TransformerDeID(dataset=cls.dataset)

    @classmethod
    def tearDownClass(cls):
        cls.dataset.close()

    def test_feature_key(self):
        self.assertEqual("text", self.model.feature_key)

    def test_label_key(self):
        self.assertEqual("labels", self.model.label_key)

    def test_num_labels(self):
        self.assertEqual(15, self.model.num_labels)

    def test_classifier_output_dim(self):
        self.assertEqual(15, self.model.classifier.out_features)

    def test_encoder_hidden_size(self):
        """BERT-base has hidden_size=768."""
        self.assertEqual(768, self.model.encoder.config.hidden_size)
+
+
class TestTransformerDeIDForward(unittest.TestCase):
    """Forward-pass, backward-pass, and de-identification behavior."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()
        cls.model = TransformerDeID(dataset=cls.dataset)
        cls.model.eval()
        # Run a forward pass with raw strings (same format as task output).
        with torch.no_grad():
            cls.result = cls.model(
                text=[
                    "Patient John Smith was seen",
                    "Admitted on 01/15/2024 to clinic",
                ],
                labels=[
                    "O B-NAME I-NAME O O",
                    "O O B-DATE O O",
                ],
            )

    @classmethod
    def tearDownClass(cls):
        cls.dataset.close()

    def test_output_has_required_keys(self):
        for key in ("loss", "logit", "y_prob", "y_true"):
            self.assertIn(key, self.result)

    def test_loss_is_scalar(self):
        self.assertEqual(self.result["loss"].dim(), 0)

    def test_logit_shape(self):
        """logit should be (batch, seq_len, num_labels)."""
        logit = self.result["logit"]
        self.assertEqual(logit.shape[0], 2)  # batch size
        self.assertEqual(logit.shape[2], 15)  # num labels

    def test_y_prob_shape_matches_logit(self):
        self.assertEqual(
            self.result["y_prob"].shape, self.result["logit"].shape
        )

    def test_y_prob_sums_to_one(self):
        """Softmax probabilities should sum to ~1 along label dim."""
        sums = self.result["y_prob"].sum(dim=-1)
        self.assertTrue(torch.allclose(sums, torch.ones_like(sums), atol=1e-5))

    def test_backward(self):
        """Loss backward should produce gradients."""
        # Need train mode and a fresh forward pass (outside no_grad) for
        # gradients to flow.
        self.model.train()
        try:
            result = self.model(
                text=["Patient John Smith was seen"],
                labels=["O B-NAME I-NAME O O"],
            )
            result["loss"].backward()
            has_grad = any(
                p.requires_grad and p.grad is not None
                for p in self.model.parameters()
            )
            self.assertTrue(has_grad)
        finally:
            # Restore eval mode and clear gradients even if an assertion
            # above fails, so sibling tests see the model exactly as
            # setUpClass left it (the class shares one model instance).
            self.model.eval()
            self.model.zero_grad()

    def test_deidentify_returns_string(self):
        result = self.model.deidentify("Patient John Smith was seen")
        self.assertIsInstance(result, str)

    def test_deidentify_same_word_count(self):
        """Output should have same number of words (redacted or not)."""
        text = "Patient John Smith was seen"
        result = self.model.deidentify(text)
        self.assertEqual(len(result.split()), len(text.split()))

    def test_deidentify_custom_redact_marker(self):
        """A custom marker must replace the default for redacted tokens.

        Asserting only that "[REDACTED]" is absent passes vacuously when
        the (untrained) model redacts nothing, so additionally require
        that every token the model changed equals the custom marker.
        """
        text = "Patient John"
        result = self.model.deidentify(text, redact="[PHI]")
        self.assertNotIn("[REDACTED]", result)
        # deidentify preserves word count (see test above), so a token-wise
        # zip pairs each output word with its source word.
        for source_word, output_word in zip(text.split(), result.split()):
            if output_word != source_word:
                self.assertEqual(output_word, "[PHI]")
+
+
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    unittest.main()