From 108ae5d7d222e9b5de437ffe9e743c8c8d2dc046 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Thu, 2 Apr 2026 23:27:23 -0700 Subject: [PATCH 01/15] Add PhysioNetDeIDDataset --- pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/configs/physionet_deid.yaml | 10 + pyhealth/datasets/physionet_deid.py | 346 ++++++++++++++++++ test-resources/core/physionet_deid/id.res | 12 + test-resources/core/physionet_deid/id.text | 12 + tests/core/test_physionet_deid.py | 110 ++++++ 6 files changed, 491 insertions(+) create mode 100644 pyhealth/datasets/configs/physionet_deid.yaml create mode 100644 pyhealth/datasets/physionet_deid.py create mode 100644 test-resources/core/physionet_deid/id.res create mode 100644 test-resources/core/physionet_deid/id.text create mode 100644 tests/core/test_physionet_deid.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..50b1b3887 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .physionet_deid import PhysioNetDeIDDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/pyhealth/datasets/configs/physionet_deid.yaml b/pyhealth/datasets/configs/physionet_deid.yaml new file mode 100644 index 000000000..2054ab809 --- /dev/null +++ b/pyhealth/datasets/configs/physionet_deid.yaml @@ -0,0 +1,10 @@ +version: "1.0" +tables: + physionet_deid: + file_path: "physionet_deid_metadata.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "note_id" + - "text" + - "labels" diff --git a/pyhealth/datasets/physionet_deid.py b/pyhealth/datasets/physionet_deid.py new file mode 100644 index 000000000..7e7599d52 --- /dev/null +++ 
b/pyhealth/datasets/physionet_deid.py @@ -0,0 +1,346 @@ +""" +PyHealth dataset for the PhysioNet De-Identification dataset. + +Dataset link: + https://physionet.org/content/deidentifiedmedicaltext/1.0/ + +Dataset paper: (please cite if you use this dataset) + Neamatullah, Ishna, et al. "Automated de-identification of free-text + medical records." BMC Medical Informatics and Decision Making 8.1 (2008). + +Paper link: + https://doi.org/10.1186/1472-6947-8-32 + +BIO tagging and PHI classification adapted from the bert-deid reference +implementation by Johnson et al.: + https://github.com/alistairewj/bert-deid/blob/master/bert_deid/label.py + +Author: + Matt McKenna (mtm16@illinois.edu) +""" +import logging +import os +import re +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import pandas as pd + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + +# -- Record parsing regexes -- + +_RECORD_START = re.compile(r"START_OF_RECORD=(\d+)\|\|\|\|(\d+)\|\|\|\|") +_RECORD_END = re.compile(r"\|\|\|\|END_OF_RECORD") +_PHI_TAG = re.compile(r"\[\*\*(.+?)\*\*\]", re.DOTALL) +_PHI_SPLIT = re.compile(r"\[\*\*(?:.+?)\*\*\]", re.DOTALL) + + +def _parse_file(path: Path) -> Dict[Tuple[str, str], str]: + """Parse a PhysioNet record file into {(patient_id, note_id): body}. + + Args: + path: Path to id.text or id.res file. + + Returns: + Dictionary mapping (patient_id, note_id) to note body text. + """ + raw = path.read_text(encoding="utf-8", errors="ignore") + out: Dict[Tuple[str, str], str] = {} + for m in _RECORD_START.finditer(raw): + pid, nid = m.group(1), m.group(2) + body_start = m.end() + end_m = _RECORD_END.search(raw, body_start) + body = raw[body_start : end_m.start() if end_m else len(raw)] + out[(pid, nid)] = body.strip() + return out + + +def classify_phi(raw: str) -> str: + """Map raw [**...**] tag text to one of the 7 PHI categories. + + Args: + raw: The text inside a [**...**] tag. 
def classify_phi(raw: str) -> str:
    """Map raw [**...**] tag text to one of the 7 PHI categories.

    Heuristic keyword matching adapted from bert-deid. Checks run from
    most to least specific; NAME is the fallback category.

    Args:
        raw: The text inside a [**...**] tag.

    Returns:
        One of: AGE, DATE, CONTACT, LOCATION, ID, PROFESSION, NAME.
    """
    t = re.sub(r"[^a-z0-9 ]+", " ", raw.strip().lower()).strip()
    # Pad with spaces so word-boundary keywords (" age ", " yo ",
    # " id ", " dr ", " md ") also match at the start or end of the
    # tag text. Without padding, a tag like "Age 105" would never
    # match " age " and would fall through to the NAME fallback.
    padded = f" {t} "

    if "year old" in t or " yo " in padded or " age " in padded:
        return "AGE"
    if any(k in t for k in ("date", "month", "day", "year", "holiday")):
        return "DATE"
    if re.fullmatch(r"[\d]{1,2}[ \-/][\d]{1,2}([ \-/][\d]{2,4})?", t):
        return "DATE"
    # Bare numbers are most often redacted dates in this corpus.
    if re.fullmatch(r"[\d]+", raw.strip()):
        return "DATE"
    if any(k in t for k in ("phone", "fax", "email", "pager", "contact")):
        return "CONTACT"
    if any(
        k in t
        for k in (
            "hospital",
            "location",
            "street",
            "county",
            "state",
            "country",
            "zip",
            "address",
            "ward",
            "room",
        )
    ):
        return "LOCATION"
    if any(
        k in t
        for k in (
            "mrn",
            "medical record",
            "record number",
            "ssn",
            "account",
            "serial",
            "unit no",
            "unit number",
            "identifier",
        )
    ):
        return "ID"
    if " id " in padded:
        return "ID"
    if (
        any(
            k in t
            for k in (
                "doctor",
                "nurse",
                "attending",
                "resident",
                "profession",
                "service",
                "provider",
            )
        )
        or " dr " in padded
        or " md " in padded
    ):
        return "PROFESSION"
    if any(
        k in t
        for k in (
            "name",
            "initial",
            "alias",
            "patient",
            "first name",
            "last name",
        )
    ):
        return "NAME"
    return "NAME"  # fallback
+ """ + parts = _PHI_SPLIT.split(deid) + tags = _PHI_TAG.findall(deid) + + spans: List[Tuple[int, int, str]] = [] + pos = 0 + + for i, tag_inner in enumerate(tags): + before = parts[i] + if before: + idx = orig.find(before, pos) + pos = (idx + len(before)) if idx != -1 else (pos + len(before)) + + phi_start = pos + + after = parts[i + 1] + if after: + idx = orig.find(after, phi_start) + phi_end = idx if idx != -1 else phi_start + else: + phi_end = len(orig) + + if phi_end > phi_start: + spans.append((phi_start, phi_end, classify_phi(tag_inner))) + + pos = phi_end + + return spans + + +def bio_tag( + text: str, spans: List[Tuple[int, int, str]] +) -> List[Tuple[str, str]]: + """Whitespace-tokenize text and assign BIO labels from char-level spans. + + Args: + text: Original note text. + spans: List of (char_start, char_end, phi_category) tuples. + + Returns: + List of (word, label) tuples. + """ + char_label = ["O"] * len(text) + for start, end, cat in spans: + for i in range(start, min(end, len(text))): + char_label[i] = cat + + result: List[Tuple[str, str]] = [] + for m in re.finditer(r"\S+", text): + w_start, w_end = m.start(), m.end() + word = m.group() + # Collect PHI categories for this token's characters, ignoring O's. + # If empty, the token has no PHI and we label it O. + cats = [c for c in char_label[w_start:w_end] if c != "O"] + if not cats: + result.append((word, "O")) + continue + # Pick the most common category. Handles rare cases where a token + # spans two PHI types, e.g. "Smith01/15" -> chars are NAME+DATE, + # majority wins (DATE). + cat = max(set(cats), key=cats.count) + # B = beginning of a new entity, I = continuation of the same one. + # Use "I" only if the previous token was the same category, + # e.g. "Tom"=B-NAME "Garcia"=I-NAME. Otherwise start a new "B". 
def bio_tag(
    text: str, spans: List[Tuple[int, int, str]]
) -> List[Tuple[str, str]]:
    """Whitespace-tokenize ``text`` and assign BIO labels from char spans.

    Args:
        text: Original note text.
        spans: List of (char_start, char_end, phi_category) tuples.

    Returns:
        List of (word, label) tuples.
    """
    # Paint every character with its PHI category ("O" = no PHI).
    painted = ["O"] * len(text)
    for lo, hi, category in spans:
        for pos in range(lo, min(hi, len(text))):
            painted[pos] = category

    tagged: List[Tuple[str, str]] = []
    for match in re.finditer(r"\S+", text):
        token = match.group()
        phi_chars = [
            c for c in painted[match.start(): match.end()] if c != "O"
        ]
        if not phi_chars:
            tagged.append((token, "O"))
            continue
        # Majority vote handles the rare token straddling two PHI
        # types, e.g. "Smith01/15" -> NAME chars + DATE chars.
        category = max(set(phi_chars), key=phi_chars.count)
        # "I" continues the previous entity of the same category
        # (e.g. "Tom"=B-NAME "Garcia"=I-NAME); otherwise a new "B".
        previous = tagged[-1][1] if tagged else "O"
        if previous != "O" and previous.endswith(category):
            tagged.append((token, f"I-{category}"))
        else:
            tagged.append((token, f"B-{category}"))

    return tagged
+ + Raises: + FileNotFoundError: If id.text or id.res is missing. + """ + for fname in ("id.text", "id.res"): + path = os.path.join(root, fname) + if not os.path.isfile(path): + raise FileNotFoundError( + f"Required file '{fname}' not found in {root}" + ) + + def _index_data(self, root: str) -> pd.DataFrame: + """Parse id.text and id.res into a CSV for BaseDataset to load. + + Args: + root: Root directory containing the data files. + + Returns: + DataFrame with columns: patient_id, note_id, text, labels. + """ + root_path = Path(root) + orig_records = _parse_file(root_path / "id.text") + deid_records = _parse_file(root_path / "id.res") + + rows = [] + for key in sorted( + orig_records, key=lambda k: (int(k[0]), int(k[1])) + ): + pid, nid = key + orig = orig_records[key] + deid = deid_records.get(key, "") # defensive: missing key yields empty string + spans = phi_spans_in_original(orig, deid) + tagged = bio_tag(orig, spans) + + tokens = " ".join(w for w, _ in tagged) + labels = " ".join(lbl for _, lbl in tagged) + + rows.append( + { + "patient_id": pid, + "note_id": nid, + "text": tokens, + "labels": labels, + } + ) + + df = pd.DataFrame(rows) + df.to_csv( + os.path.join(root, "physionet_deid_metadata.csv"), index=False + ) + return df diff --git a/test-resources/core/physionet_deid/id.res b/test-resources/core/physionet_deid/id.res new file mode 100644 index 000000000..1081b2717 --- /dev/null +++ b/test-resources/core/physionet_deid/id.res @@ -0,0 +1,12 @@ +START_OF_RECORD=10||||1|||| +Patient [**First Name 101**] [**Last Name 102**] was admitted on [**Date 103**] to [**Hospital 104**]. She is [**Age 105**] years old. Contact phone [**Phone 106**]. MRN [**Medical Record Number 107**]. +||||END_OF_RECORD +START_OF_RECORD=10||||2|||| +Seen by [**Doctor First Name 201**] [**Doctor Last Name 202**] in radiology. No acute findings. +||||END_OF_RECORD +START_OF_RECORD=20||||1|||| +Assessment unchanged. Vitals stable overnight. Continue current plan. 
+||||END_OF_RECORD +START_OF_RECORD=60||||1|||| +[**First Name 301**] [**Last Name 302**] presented to [**Hospital 303**] on [**Date 304**] for follow-up. He works as a [**Profession 305**]. Patient ID [**Medical Record Number 306**]. +||||END_OF_RECORD diff --git a/test-resources/core/physionet_deid/id.text b/test-resources/core/physionet_deid/id.text new file mode 100644 index 000000000..e61e9f09d --- /dev/null +++ b/test-resources/core/physionet_deid/id.text @@ -0,0 +1,12 @@ +START_OF_RECORD=10||||1|||| +Patient Jane Doe was admitted on 03/12/2098 to Springfield General Hospital. She is 72 years old. Contact phone 555-867-5309. MRN 00112233. +||||END_OF_RECORD +START_OF_RECORD=10||||2|||| +Seen by Dr. Robert Wells in radiology. No acute findings. +||||END_OF_RECORD +START_OF_RECORD=20||||1|||| +Assessment unchanged. Vitals stable overnight. Continue current plan. +||||END_OF_RECORD +START_OF_RECORD=60||||1|||| +Tom Garcia presented to Lakewood Clinic on 11/05/2097 for follow-up. He works as a plumber. Patient ID 99887766. +||||END_OF_RECORD diff --git a/tests/core/test_physionet_deid.py b/tests/core/test_physionet_deid.py new file mode 100644 index 000000000..632a52a88 --- /dev/null +++ b/tests/core/test_physionet_deid.py @@ -0,0 +1,110 @@ +""" +Unit tests for the PhysioNetDeIDDataset class. 
class TestPhysioNetDeIDDataset(unittest.TestCase):
    """End-to-end checks against the bundled miniature PhysioNet corpus."""

    @classmethod
    def setUpClass(cls):
        cls.root = (
            Path(__file__).parent.parent.parent
            / "test-resources"
            / "core"
            / "physionet_deid"
        )
        cls.cache_dir = tempfile.TemporaryDirectory()
        cls.dataset = PhysioNetDeIDDataset(
            root=str(cls.root), cache_dir=cls.cache_dir.name
        )

    @classmethod
    def tearDownClass(cls):
        cls.cache_dir.cleanup()
        # The dataset writes its index CSV into the test-resource dir;
        # remove it so repeated runs start clean.
        generated = cls.root / "physionet_deid_metadata.csv"
        if generated.exists():
            generated.unlink()

    def _events(self, patient_id):
        """Return all note events for one patient."""
        return self.dataset.get_patient(patient_id).get_events()

    def test_num_patients(self):
        self.assertEqual(len(self.dataset.unique_patient_ids), 3)

    def test_patient_ids(self):
        self.assertEqual(
            set(self.dataset.unique_patient_ids), {"10", "20", "60"}
        )

    def test_patient_10_has_two_notes(self):
        self.assertEqual(len(self._events("10")), 2)

    def test_patient_20_has_one_note(self):
        self.assertEqual(len(self._events("20")), 1)

    def test_patient_60_has_one_note(self):
        self.assertEqual(len(self._events("60")), 1)

    def test_patient_10_note1_has_tokens_and_labels(self):
        first_note = self._events("10")[0]
        self.assertIn("text", first_note)
        self.assertIn("labels", first_note)

    def test_patient_10_note1_token_count(self):
        """Note 1 for patient 10 should have the right number of tokens."""
        first_note = self._events("10")[0]
        self.assertEqual(len(first_note["text"].split(" ")), 21)

    def test_patient_20_no_phi(self):
        """Patient 20's note has no PHI, all labels should be O."""
        labels = self._events("20")[0]["labels"].split(" ")
        self.assertEqual(set(labels), {"O"})

    def test_patient_60_has_name_labels(self):
        """Patient 60's note starts with NAME."""
        labels = self._events("60")[0]["labels"].split(" ")
        self.assertEqual(labels[:2], ["B-NAME", "I-NAME"])

    def test_patient_60_has_location_label(self):
        """Patient 60's note has LOCATION."""
        note = self._events("60")[0]
        tokens = note["text"].split(" ")
        labels = note["labels"].split(" ")
        at = tokens.index("Lakewood")
        self.assertEqual(labels[at], "B-LOCATION")
        self.assertEqual(labels[at + 1], "I-LOCATION")

    def test_patient_60_has_date_label(self):
        """Patient 60's note has DATE."""
        note = self._events("60")[0]
        tokens = note["text"].split(" ")
        labels = note["labels"].split(" ")
        self.assertEqual(labels[tokens.index("11/05/2097")], "B-DATE")

    def test_patient_60_has_profession_label(self):
        """Patient 60's note has PROFESSION."""
        note = self._events("60")[0]
        tokens = note["text"].split(" ")
        labels = note["labels"].split(" ")
        self.assertEqual(labels[tokens.index("plumber.")], "B-PROFESSION")

    def test_stats(self):
        self.dataset.stats()


if __name__ == "__main__":
    unittest.main()
class DeIDNERTask(BaseTask):
    """Token-level NER task for clinical text de-identification.

    Each sample carries whitespace-joined tokens and their BIO labels
    over 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME,
    PROFESSION.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for the task input.
        output_schema (Dict[str, str]): The schema for the task output.

    Examples:
        >>> from pyhealth.datasets import PhysioNetDeIDDataset
        >>> from pyhealth.tasks import DeIDNERTask
        >>> dataset = PhysioNetDeIDDataset(root="/path/to/data")
        >>> task = DeIDNERTask()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "DeIDNER"
    input_schema: Dict[str, str] = {"text": "text"}
    # Labels stay a space-joined string; the model splits and encodes
    # them into label indices.
    output_schema: Dict[str, Union[str, Type]] = {"labels": TextProcessor}

    def __call__(self, patient: Patient) -> List[Dict]:
        """Generate one NER sample per clinical note.

        Args:
            patient: A Patient object with physionet_deid events.

        Returns:
            List of dicts, each with 'text' (str) and 'labels' (str)
            keys. Both are space-joined strings.
        """
        notes: List[Event] = patient.get_events(
            event_type="physionet_deid"
        )
        return [
            {"text": note["text"], "labels": note["labels"]}
            for note in notes
        ]
Author: Matt McKenna (mtm16@illinois.edu) @@ -9,6 +9,7 @@ import unittest from pyhealth.datasets import PhysioNetDeIDDataset +from pyhealth.tasks import DeIDNERTask class TestPhysioNetDeIDDataset(unittest.TestCase): @@ -24,9 +25,12 @@ def setUpClass(cls): cls.dataset = PhysioNetDeIDDataset( root=str(cls.root), cache_dir=cls.cache_dir.name ) + cls.task = DeIDNERTask() + cls.samples = cls.dataset.set_task(cls.task) @classmethod def tearDownClass(cls): + cls.samples.close() cls.cache_dir.cleanup() metadata = cls.root / "physionet_deid_metadata.csv" if metadata.exists(): @@ -105,6 +109,32 @@ def test_patient_60_has_profession_label(self): def test_stats(self): self.dataset.stats() + # -- Task tests -- + + def test_task_sample_count(self): + """4 notes total across 3 patients.""" + self.assertEqual(len(self.samples), 4) + + def test_task_sample_has_text_and_labels(self): + sample = self.samples[0] + self.assertIn("text", sample) + self.assertIn("labels", sample) + + def test_task_text_and_labels_same_length(self): + for sample in self.samples: + tokens = sample["text"].split(" ") + labels = sample["labels"].split(" ") + self.assertEqual(len(tokens), len(labels)) + + def test_task_labels_are_valid_bio(self): + valid = {"O"} + for cat in ("AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION"): + valid.add(f"B-{cat}") + valid.add(f"I-{cat}") + for sample in self.samples: + for label in sample["labels"].split(" "): + self.assertIn(label, valid) + if __name__ == "__main__": unittest.main() From 40f018c737081e6b83e3b86da24d6abb1240fe16 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Fri, 10 Apr 2026 15:10:14 -0700 Subject: [PATCH 03/15] Add BertDeID model for BERT-based clinical text de-identification and corresponding unit tests --- pyhealth/models/__init__.py | 1 + pyhealth/models/bert_deid.py | 158 +++++++++++++++++++++++++++++++++++ tests/core/test_bert_deid.py | 113 +++++++++++++++++++++++++ 3 files changed, 272 insertions(+) create mode 100644 
pyhealth/models/bert_deid.py create mode 100644 tests/core/test_bert_deid.py diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..029dee44e 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,6 +1,7 @@ from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare from .agent import Agent, AgentLayer from .base_model import BaseModel +from .bert_deid import BertDeID from .biot import BIOT from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer diff --git a/pyhealth/models/bert_deid.py b/pyhealth/models/bert_deid.py new file mode 100644 index 000000000..7e0a1370f --- /dev/null +++ b/pyhealth/models/bert_deid.py @@ -0,0 +1,158 @@ +""" +PyHealth model for BERT-based clinical text de-identification. + +Performs token-level NER to detect PHI (protected health information) +in clinical notes using a pre-trained transformer with a classification +head. + +Paper: Johnson, Alistair E.W., et al. "Deidentification of free-text + medical records using pre-trained bidirectional transformers." + Proceedings of the ACM Conference on Health, Inference, and + Learning (CHIL), 2020. + +Paper link: + https://doi.org/10.1145/3368555.3384455 + +Model structure (dropout + linear head) follows PyHealth's +TransformersModel (pyhealth/models/transformers_model.py), adapted +for token-level classification instead of sequence-level. + +Subword alignment logic adapted from the reference implementation: + https://github.com/alistairewj/bert-deid + +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +import logging +from typing import Dict, List + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer + +from ..datasets import SampleDataset +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +# 7 PHI categories with BIO prefix, plus O for non-PHI. 
+LABEL_VOCAB = { + "O": 0, + "B-AGE": 1, "I-AGE": 2, + "B-CONTACT": 3, "I-CONTACT": 4, + "B-DATE": 5, "I-DATE": 6, + "B-ID": 7, "I-ID": 8, + "B-LOCATION": 9, "I-LOCATION": 10, + "B-NAME": 11, "I-NAME": 12, + "B-PROFESSION": 13, "I-PROFESSION": 14, +} + +# Cross-entropy ignores positions with this index (PyTorch convention). +IGNORE_INDEX = -100 + + +def align_labels( + word_ids: List[int | None], + word_labels: List[int], +) -> List[int]: + """Align word-level labels to subword tokens. + + BERT/RoBERTa tokenizers split words into subwords. For example, + "Smith" might become ["Sm", "##ith"]. This function assigns the + word's label to the first subtoken and IGNORE_INDEX to the rest, + so the loss function skips non-first subtokens. Special tokens + ([CLS], [SEP], padding) have word_id=None and also get + IGNORE_INDEX. + + Args: + word_ids: Output of tokenizer.word_ids(). None for special + tokens, integer word index for real tokens. + word_labels: Label index for each word in the original text. + + Returns: + List of label indices, one per subtoken. Non-first subtokens + and special tokens are set to IGNORE_INDEX (-100). + """ + aligned = [] + prev_word_id = None + for word_id in word_ids: + if word_id is None: + # Special token ([CLS], [SEP], padding). + aligned.append(IGNORE_INDEX) + elif word_id != prev_word_id: + # First subtoken of a word: use the word's label. + aligned.append(word_labels[word_id]) + else: + # Non-first subtoken: ignore during loss computation. + aligned.append(IGNORE_INDEX) + prev_word_id = word_id + return aligned + + +class BertDeID(BaseModel): + """BERT-based token classifier for clinical text de-identification. + + Uses a pre-trained transformer encoder with a linear classification + head to predict BIO-tagged PHI labels for each token. + + Args: + dataset: A SampleDataset from set_task(). + model_name: HuggingFace model name. Default "bert-base-uncased". + max_length: Maximum token sequence length. Default 512. 
+ dropout: Dropout rate for the classification head. Default 0.1. + + Examples: + >>> from pyhealth.datasets import PhysioNetDeIDDataset + >>> from pyhealth.tasks import DeIDNERTask + >>> from pyhealth.models import BertDeID + >>> dataset = PhysioNetDeIDDataset(root="/path/to/data") + >>> samples = dataset.set_task(DeIDNERTask()) + >>> model = BertDeID(dataset=samples) + """ + + def __init__( + self, + dataset: SampleDataset, + model_name: str = "bert-base-uncased", + max_length: int = 512, + dropout: float = 0.1, + ): + super(BertDeID, self).__init__(dataset=dataset) + + assert len(self.feature_keys) == 1, ( + "BertDeID expects exactly one input feature (text)." + ) + assert len(self.label_keys) == 1, ( + "BertDeID expects exactly one label key." + ) + self.feature_key = self.feature_keys[0] + self.label_key = self.label_keys[0] + + self.model_name = model_name + self.max_length = max_length + self.label_vocab = LABEL_VOCAB + self.num_labels = len(LABEL_VOCAB) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.encoder = AutoModel.from_pretrained( + model_name, + hidden_dropout_prob=dropout, + attention_probs_dropout_prob=dropout, + ) + hidden_size = self.encoder.config.hidden_size + self.dropout = nn.Dropout(dropout) + self.classifier = nn.Linear(hidden_size, self.num_labels) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Args: + **kwargs: Must contain self.feature_key (list of + space-joined token strings) and self.label_key + (list of space-joined BIO label strings). + + Returns: + Dict with keys: loss, logit, y_prob, y_true. + """ + raise NotImplementedError("TODO") diff --git a/tests/core/test_bert_deid.py b/tests/core/test_bert_deid.py new file mode 100644 index 000000000..dbc9efce8 --- /dev/null +++ b/tests/core/test_bert_deid.py @@ -0,0 +1,113 @@ +""" +Unit tests for the BertDeID model. 
def _make_dataset():
    """Create a minimal in-memory dataset matching DeIDNERTask output."""
    rows = [
        {
            "patient_id": "p1",
            "text": "Patient John Smith was seen",
            "labels": "O B-NAME I-NAME O O",
        },
        {
            "patient_id": "p2",
            "text": "Admitted on 01/15/2024 to clinic",
            "labels": "O O B-DATE O O",
        },
    ]
    return create_sample_dataset(
        samples=rows,
        input_schema={"text": TextProcessor},
        output_schema={"labels": TextProcessor},
        dataset_name="test_deid",
        task_name="DeIDNER",
        in_memory=True,
    )


class TestLabelVocab(unittest.TestCase):
    def test_vocab_size(self):
        """O + 7 categories * 2 (B/I) = 15."""
        self.assertEqual(len(LABEL_VOCAB), 15)

    def test_o_is_zero(self):
        self.assertEqual(LABEL_VOCAB["O"], 0)

    def test_all_categories_present(self):
        categories = (
            "AGE", "CONTACT", "DATE", "ID", "LOCATION", "NAME", "PROFESSION",
        )
        for cat in categories:
            self.assertIn(f"B-{cat}", LABEL_VOCAB)
            self.assertIn(f"I-{cat}", LABEL_VOCAB)


class TestAlignLabels(unittest.TestCase):
    def test_no_subword_splits(self):
        """When every word is a single token, labels pass through."""
        # word_ids: None=CLS, 0, 1, 2, None=SEP; labels: O, B-NAME, I-NAME
        got = align_labels([None, 0, 1, 2, None], [0, 11, 12])
        self.assertEqual(got, [IGNORE_INDEX, 0, 11, 12, IGNORE_INDEX])

    def test_subword_split(self):
        """Non-first subtokens should get IGNORE_INDEX."""
        # "Smith" split into 2 subtokens (word_id=1 twice).
        got = align_labels([None, 0, 1, 1, 2, None], [0, 12, 0])
        self.assertEqual(
            got, [IGNORE_INDEX, 0, 12, IGNORE_INDEX, 0, IGNORE_INDEX]
        )

    def test_all_special_tokens(self):
        """All-None word_ids should produce all IGNORE_INDEX."""
        self.assertEqual(
            align_labels([None, None], []), [IGNORE_INDEX, IGNORE_INDEX]
        )


class TestBertDeIDInit(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()
        cls.model = BertDeID(dataset=cls.dataset)

    @classmethod
    def tearDownClass(cls):
        cls.dataset.close()

    def test_feature_key(self):
        self.assertEqual(self.model.feature_key, "text")

    def test_label_key(self):
        self.assertEqual(self.model.label_key, "labels")

    def test_num_labels(self):
        self.assertEqual(self.model.num_labels, 15)

    def test_classifier_output_dim(self):
        self.assertEqual(self.model.classifier.out_features, 15)

    def test_encoder_hidden_size(self):
        """BERT-base has hidden_size=768."""
        self.assertEqual(self.model.encoder.config.hidden_size, 768)


if __name__ == "__main__":
    unittest.main()
7e0a1370f..bb5950671 100644 --- a/pyhealth/models/bert_deid.py +++ b/pyhealth/models/bert_deid.py @@ -1,5 +1,5 @@ """ -PyHealth model for BERT-based clinical text de-identification. +PyHealth model for transformer-based clinical text de-identification. Performs token-level NER to detect PHI (protected health information) in clinical notes using a pre-trained transformer with a classification @@ -90,8 +90,8 @@ def align_labels( return aligned -class BertDeID(BaseModel): - """BERT-based token classifier for clinical text de-identification. +class TransformerDeID(BaseModel): + """Transformer-based token classifier for clinical text de-identification. Uses a pre-trained transformer encoder with a linear classification head to predict BIO-tagged PHI labels for each token. @@ -105,10 +105,11 @@ class BertDeID(BaseModel): Examples: >>> from pyhealth.datasets import PhysioNetDeIDDataset >>> from pyhealth.tasks import DeIDNERTask - >>> from pyhealth.models import BertDeID + >>> from pyhealth.models import TransformerDeID >>> dataset = PhysioNetDeIDDataset(root="/path/to/data") >>> samples = dataset.set_task(DeIDNERTask()) - >>> model = BertDeID(dataset=samples) + >>> model = TransformerDeID(dataset=samples) # BERT + >>> model = TransformerDeID(dataset=samples, model_name="roberta-base") """ def __init__( @@ -118,13 +119,13 @@ def __init__( max_length: int = 512, dropout: float = 0.1, ): - super(BertDeID, self).__init__(dataset=dataset) + super(TransformerDeID, self).__init__(dataset=dataset) assert len(self.feature_keys) == 1, ( - "BertDeID expects exactly one input feature (text)." + "TransformerDeID expects exactly one input feature (text)." ) assert len(self.label_keys) == 1, ( - "BertDeID expects exactly one label key." + "TransformerDeID expects exactly one label key." 
) self.feature_key = self.feature_keys[0] self.label_key = self.label_keys[0] diff --git a/tests/core/test_bert_deid.py b/tests/core/test_bert_deid.py index dbc9efce8..e47b69548 100644 --- a/tests/core/test_bert_deid.py +++ b/tests/core/test_bert_deid.py @@ -1,5 +1,5 @@ """ -Unit tests for the BertDeID model. +Unit tests for the TransformerDeID model. Author: Matt McKenna (mtm16@illinois.edu) @@ -11,7 +11,7 @@ from pyhealth.models.bert_deid import ( IGNORE_INDEX, LABEL_VOCAB, - BertDeID, + TransformerDeID, align_labels, ) from pyhealth.processors.text_processor import TextProcessor @@ -82,11 +82,11 @@ def test_all_special_tokens(self): self.assertEqual(result, [IGNORE_INDEX, IGNORE_INDEX]) -class TestBertDeIDInit(unittest.TestCase): +class TestTransformerDeIDInit(unittest.TestCase): @classmethod def setUpClass(cls): cls.dataset = _make_dataset() - cls.model = BertDeID(dataset=cls.dataset) + cls.model = TransformerDeID(dataset=cls.dataset) @classmethod def tearDownClass(cls): From 74cf0d5d052860ab31cfa49516ab4fb7b0f86f15 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Fri, 10 Apr 2026 22:45:08 -0700 Subject: [PATCH 05/15] Refactor TransformerDeID model: rename file and update tests --- pyhealth/models/__init__.py | 2 +- .../{bert_deid.py => transformer_deid.py} | 63 ++++++++++++++++- ..._bert_deid.py => test_transformer_deid.py} | 67 ++++++++++++++++++- 3 files changed, 129 insertions(+), 3 deletions(-) rename pyhealth/models/{bert_deid.py => transformer_deid.py} (68%) rename tests/core/{test_bert_deid.py => test_transformer_deid.py} (60%) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 6b660a829..deabaf95c 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,7 +1,7 @@ from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare from .agent import Agent, AgentLayer from .base_model import BaseModel -from .bert_deid import TransformerDeID +from .transformer_deid import TransformerDeID from .biot import BIOT 
from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer diff --git a/pyhealth/models/bert_deid.py b/pyhealth/models/transformer_deid.py similarity index 68% rename from pyhealth/models/bert_deid.py rename to pyhealth/models/transformer_deid.py index bb5950671..6d14cfe38 100644 --- a/pyhealth/models/bert_deid.py +++ b/pyhealth/models/transformer_deid.py @@ -156,4 +156,65 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: Returns: Dict with keys: loss, logit, y_prob, y_true. """ - raise NotImplementedError("TODO") + texts: List[str] = kwargs[self.feature_key] + label_strings: List[str] = kwargs[self.label_key] + + # Tokenize with is_split_into_words=True so the tokenizer + # knows word boundaries and word_ids() works correctly. + words_batch = [t.split(" ") for t in texts] + encoding = self.tokenizer( + words_batch, + is_split_into_words=True, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + + # Convert word-level label strings to indices, then align + # to subword tokens. Positions that should be ignored during + # loss (special tokens, non-first subtokens, padding) get + # IGNORE_INDEX (-100), which cross-entropy skips. + aligned_labels = [] + for i, label_str in enumerate(label_strings): + word_labels = [ + self.label_vocab[lbl] for lbl in label_str.split(" ") + ] + word_ids = encoding.word_ids(batch_index=i) + aligned_labels.append(align_labels(word_ids, word_labels)) + + labels = torch.tensor(aligned_labels, dtype=torch.long) + + # Move to device + input_ids = encoding["input_ids"].to(self.device) + attention_mask = encoding["attention_mask"].to(self.device) + labels = labels.to(self.device) + + # Encoder -> dropout -> classifier (per-token logits) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state + logits = self.classifier(self.dropout(hidden_states)) + + # Token-level cross-entropy, ignoring padded/special positions. 
+ # We can't use BaseModel.get_loss_function() because it assumes + # one label per sample. Instead we call cross_entropy directly + # with ignore_index to skip special tokens and non-first subtokens. + # Flatten + ignore_index pattern from HuggingFace's + # BertForTokenClassification.forward(). + loss = nn.functional.cross_entropy( + logits.view(-1, self.num_labels), + labels.view(-1), + ignore_index=IGNORE_INDEX, + ) + + # Per-token probabilities via softmax. + y_prob = torch.softmax(logits, dim=-1) + + return { + "loss": loss, + "logit": logits, + "y_prob": y_prob, + "y_true": labels, + } diff --git a/tests/core/test_bert_deid.py b/tests/core/test_transformer_deid.py similarity index 60% rename from tests/core/test_bert_deid.py rename to tests/core/test_transformer_deid.py index e47b69548..727691a3b 100644 --- a/tests/core/test_bert_deid.py +++ b/tests/core/test_transformer_deid.py @@ -7,8 +7,10 @@ import unittest +import torch + from pyhealth.datasets import create_sample_dataset -from pyhealth.models.bert_deid import ( +from pyhealth.models.transformer_deid import ( IGNORE_INDEX, LABEL_VOCAB, TransformerDeID, @@ -109,5 +111,68 @@ def test_encoder_hidden_size(self): self.assertEqual(self.model.encoder.config.hidden_size, 768) +class TestTransformerDeIDForward(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = TransformerDeID(dataset=cls.dataset) + cls.model.eval() + # Run a forward pass with raw strings (same format as task output). 
+ with torch.no_grad(): + cls.result = cls.model( + text=[ + "Patient John Smith was seen", + "Admitted on 01/15/2024 to clinic", + ], + labels=[ + "O B-NAME I-NAME O O", + "O O B-DATE O O", + ], + ) + + @classmethod + def tearDownClass(cls): + cls.dataset.close() + + def test_output_has_required_keys(self): + for key in ("loss", "logit", "y_prob", "y_true"): + self.assertIn(key, self.result) + + def test_loss_is_scalar(self): + self.assertEqual(self.result["loss"].dim(), 0) + + def test_logit_shape(self): + """logit should be (batch, seq_len, num_labels).""" + logit = self.result["logit"] + self.assertEqual(logit.shape[0], 2) # batch size + self.assertEqual(logit.shape[2], 15) # num labels + + def test_y_prob_shape_matches_logit(self): + self.assertEqual( + self.result["y_prob"].shape, self.result["logit"].shape + ) + + def test_y_prob_sums_to_one(self): + """Softmax probabilities should sum to ~1 along label dim.""" + sums = self.result["y_prob"].sum(dim=-1) + self.assertTrue(torch.allclose(sums, torch.ones_like(sums), atol=1e-5)) + + def test_backward(self): + """Loss backward should produce gradients.""" + dataset = _make_dataset() + model = TransformerDeID(dataset=dataset) + result = model( + text=["Patient John Smith was seen"], + labels=["O B-NAME I-NAME O O"], + ) + result["loss"].backward() + has_grad = any( + p.requires_grad and p.grad is not None + for p in model.parameters() + ) + self.assertTrue(has_grad) + dataset.close() + + if __name__ == "__main__": unittest.main() From 5de2b58b780005995942d0deed8b46fa8abe19fe Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 09:50:35 -0700 Subject: [PATCH 06/15] WIP e2e script --- .../physionet_deid_ner_transformer_deid.py | 126 ++++++++++++++++++ pyhealth/tasks/deid_ner.py | 1 + 2 files changed, 127 insertions(+) create mode 100644 examples/physionet_deid_ner_transformer_deid.py diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py new file 
mode 100644 index 000000000..83c8894c8 --- /dev/null +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -0,0 +1,126 @@ +"""Train and evaluate TransformerDeID on PhysioNet de-identification. + +End-to-end example: load data, train BERT-base on token-level NER for +PHI detection, and report binary (PHI vs non-PHI) precision/recall/F1. + +Paper: Johnson et al. "Deidentification of free-text medical records + using pre-trained bidirectional transformers." CHIL, 2020. + +Script structure follows examples/cardiology_detection_isAR_SparcNet.py. + +Hyperparameters follow the paper (Section 3.4): + - Learning rate: 5e-5 + - Batch size: 8 + - Epochs: 3 + - Weight decay: 0.01 + +Usage: + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/deidentifiedmedicaltext/1.0 + +Author: + Matt McKenna (mtm16@illinois.edu) +""" + +import argparse + +import torch +from sklearn.metrics import precision_score, recall_score, f1_score + +from pyhealth.datasets import PhysioNetDeIDDataset, get_dataloader +from pyhealth.datasets.splitter import split_by_patient +from pyhealth.models.transformer_deid import ( + IGNORE_INDEX, + TransformerDeID, +) +from pyhealth.tasks import DeIDNERTask +from pyhealth.trainer import Trainer + + +def compute_metrics(model, dataloader): + """Binary PHI vs non-PHI token-level metrics. + + Any non-O label counts as PHI. + """ + all_true, all_pred = [], [] + model.eval() + with torch.no_grad(): + for batch in dataloader: + result = model(**batch) + preds = result["logit"].argmax(dim=-1) + labels = result["y_true"] + mask = labels != IGNORE_INDEX + all_true.extend(labels[mask].cpu().tolist()) + all_pred.extend(preds[mask].cpu().tolist()) + + # Convert to binary: O (index 0) = 0, any PHI = 1. 
+ true_bin = [0 if t == 0 else 1 for t in all_true] + pred_bin = [0 if p == 0 else 1 for p in all_pred] + return { + "precision": precision_score(true_bin, pred_bin), + "recall": recall_score(true_bin, pred_bin), + "f1": f1_score(true_bin, pred_bin), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + required=True, + help="Path to deidentifiedmedicaltext/1.0 directory", + ) + parser.add_argument("--model_name", type=str, default="bert-base-uncased") + parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + # 1. Load dataset and set task. + print("Loading dataset...") + dataset = PhysioNetDeIDDataset(root=args.data_root) + samples = dataset.set_task(DeIDNERTask()) + print(f" Patients: {len(dataset.unique_patient_ids)}, Samples: {len(samples)}") + + # 2. Split by patient (80/10/10) so no patient's notes appear in + # both train and test. + train_data, val_data, test_data = split_by_patient( + samples, [0.8, 0.1, 0.1], seed=args.seed + ) + train_loader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader(test_data, batch_size=args.batch_size, shuffle=False) + + # 3. Create model. + model = TransformerDeID( + dataset=samples, + model_name=args.model_name, + ) + + # 4. Train using PyHealth's Trainer. + device = "cuda" if torch.cuda.is_available() else "cpu" + trainer = Trainer(model=model, device=device) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.epochs, + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": args.lr}, + weight_decay=0.01, + monitor="loss", + monitor_criterion="min", + ) + + # 5. Evaluate on test set. 
+ print("\n=== Test Set Results (binary PHI vs non-PHI) ===") + metrics = compute_metrics(model, test_loader) + for k, v in metrics.items(): + print(f" {k}: {v:.4f}") + + samples.close() + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py index 58fe21867..80bf9c94d 100644 --- a/pyhealth/tasks/deid_ner.py +++ b/pyhealth/tasks/deid_ner.py @@ -72,6 +72,7 @@ def __call__(self, patient: Patient) -> List[Dict]: samples = [] for event in events: samples.append({ + "patient_id": patient.patient_id, "text": event["text"], "labels": event["labels"], }) From 64d886e4ac0407e1e353ea3041305bf18c217191 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 13:09:18 -0700 Subject: [PATCH 07/15] add windowing --- .../physionet_deid_ner_transformer_deid.py | 9 +++- pyhealth/tasks/deid_ner.py | 48 ++++++++++++++++--- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py index 83c8894c8..36abd2260 100644 --- a/examples/physionet_deid_ner_transformer_deid.py +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -75,13 +75,20 @@ def main(): parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--window_size", type=int, default=None, + help="Token window size (default: no windowing)") + parser.add_argument("--window_overlap", type=int, default=0) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() # 1. Load dataset and set task. 
print("Loading dataset...") dataset = PhysioNetDeIDDataset(root=args.data_root) - samples = dataset.set_task(DeIDNERTask()) + task = DeIDNERTask( + window_size=args.window_size, + window_overlap=args.window_overlap, + ) + samples = dataset.set_task(task) print(f" Patients: {len(dataset.unique_patient_ids)}, Samples: {len(samples)}") # 2. Split by patient (80/10/10) so no patient's notes appear in diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py index 80bf9c94d..539c410fd 100644 --- a/pyhealth/tasks/deid_ner.py +++ b/pyhealth/tasks/deid_ner.py @@ -20,7 +20,7 @@ """ import logging -from typing import Dict, List, Type, Union +from typing import Dict, List, Optional, Type, Union from pyhealth.data import Event, Patient from pyhealth.processors.text_processor import TextProcessor @@ -36,6 +36,15 @@ class DeIDNERTask(BaseTask): 7 PHI categories: AGE, CONTACT, DATE, ID, LOCATION, NAME, PROFESSION. + Supports optional overlapping windowing (paper Section 3.3) to + handle notes longer than BERT's 512 token limit. + + Args: + window_size: If set, split notes into overlapping windows of + this many tokens. Default None (no windowing). + window_overlap: Number of tokens shared between consecutive + windows. Default 0. + Attributes: task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. @@ -47,6 +56,8 @@ class DeIDNERTask(BaseTask): >>> dataset = PhysioNetDeIDDataset(root="/path/to/data") >>> task = DeIDNERTask() >>> samples = dataset.set_task(task) + >>> task_windowed = DeIDNERTask(window_size=100, window_overlap=60) + >>> samples = dataset.set_task(task_windowed) """ task_name: str = "DeIDNER" @@ -55,6 +66,14 @@ class DeIDNERTask(BaseTask): # encodes them into label indices. 
output_schema: Dict[str, Union[str, Type]] = {"labels": TextProcessor} + def __init__( + self, + window_size: Optional[int] = None, + window_overlap: int = 0, + ): + self.window_size = window_size + self.window_overlap = window_overlap + def __call__(self, patient: Patient) -> List[Dict]: """Generate NER samples from a patient's clinical notes. @@ -71,10 +90,27 @@ def __call__(self, patient: Patient) -> List[Dict]: samples = [] for event in events: - samples.append({ - "patient_id": patient.patient_id, - "text": event["text"], - "labels": event["labels"], - }) + words = event["text"].split(" ") + labels = event["labels"].split(" ") + + if self.window_size is None: + # No windowing: one sample per note. + samples.append({ + "patient_id": patient.patient_id, + "text": event["text"], + "labels": event["labels"], + }) + else: + # Overlapping windows (paper Section 3.3). + step = self.window_size - self.window_overlap + idx = 0 + while idx < len(words): + end = min(idx + self.window_size, len(words)) + samples.append({ + "patient_id": patient.patient_id, + "text": " ".join(words[idx:end]), + "labels": " ".join(labels[idx:end]), + }) + idx += step return samples From b39738523bf30a3fad8b6fe68db483082b6eec7a Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 13:55:56 -0700 Subject: [PATCH 08/15] merge windows before deciding if PHI --- .../physionet_deid_ner_transformer_deid.py | 52 ++++++++++++++++--- pyhealth/tasks/deid_ner.py | 5 ++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py index 36abd2260..059fd4972 100644 --- a/examples/physionet_deid_ner_transformer_deid.py +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -23,7 +23,9 @@ """ import argparse +from collections import defaultdict +import numpy as np import torch from sklearn.metrics import precision_score, recall_score, f1_score @@ -31,6 +33,7 @@ from pyhealth.datasets.splitter 
import split_by_patient from pyhealth.models.transformer_deid import ( IGNORE_INDEX, + LABEL_VOCAB, TransformerDeID, ) from pyhealth.tasks import DeIDNERTask @@ -38,20 +41,53 @@ def compute_metrics(model, dataloader): - """Binary PHI vs non-PHI token-level metrics. + """Binary PHI vs non-PHI token-level metrics with window merging. - Any non-O label counts as PHI. + When windowing is used, multiple windows may cover the same token. + We merge by taking the non-O prediction with highest probability + (paper Section 3.3). Without windowing, each token appears once + so no merging is needed. """ - all_true, all_pred = [], [] + # Collect per-token gold labels and prediction probabilities, + # keyed by (patient_id, note_id, absolute_token_position). + token_gold = {} + token_preds = defaultdict(list) + model.eval() with torch.no_grad(): for batch in dataloader: result = model(**batch) - preds = result["logit"].argmax(dim=-1) - labels = result["y_true"] - mask = labels != IGNORE_INDEX - all_true.extend(labels[mask].cpu().tolist()) - all_pred.extend(preds[mask].cpu().tolist()) + probs = result["y_prob"] # (batch, seq_len, num_labels) + labels = result["y_true"] # (batch, seq_len) + patient_ids = batch["patient_id"] + note_ids = batch["note_id"] + token_starts = batch["token_start"] + + for i in range(len(patient_ids)): + pid = patient_ids[i] + nid = note_ids[i] + start = int(token_starts[i]) + word_idx = 0 + for j in range(labels.shape[1]): + if labels[i, j].item() == IGNORE_INDEX: + continue + key = (pid, nid, start + word_idx) + token_gold[key] = labels[i, j].item() + token_preds[key].append(probs[i, j].cpu().numpy()) + word_idx += 1 + + # Merge overlapping predictions (paper Section 3.3): + # if any window predicts non-O, take the non-O with highest score. 
+ all_true, all_pred = [], [] + for key in sorted(token_gold): + all_true.append(token_gold[key]) + preds = token_preds[key] + non_o = [(p, p[1:].max()) for p in preds if np.argmax(p) != 0] + if non_o: + merged = max(non_o, key=lambda x: x[1])[0] + else: + merged = np.mean(preds, axis=0) + all_pred.append(int(np.argmax(merged))) # Convert to binary: O (index 0) = 0, any PHI = 1. true_bin = [0 if t == 0 else 1 for t in all_true] diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py index 539c410fd..9324a31b6 100644 --- a/pyhealth/tasks/deid_ner.py +++ b/pyhealth/tasks/deid_ner.py @@ -90,6 +90,7 @@ def __call__(self, patient: Patient) -> List[Dict]: samples = [] for event in events: + note_id = event["note_id"] words = event["text"].split(" ") labels = event["labels"].split(" ") @@ -97,6 +98,8 @@ def __call__(self, patient: Patient) -> List[Dict]: # No windowing: one sample per note. samples.append({ "patient_id": patient.patient_id, + "note_id": note_id, + "token_start": "0", "text": event["text"], "labels": event["labels"], }) @@ -108,6 +111,8 @@ def __call__(self, patient: Patient) -> List[Dict]: end = min(idx + self.window_size, len(words)) samples.append({ "patient_id": patient.patient_id, + "note_id": note_id, + "token_start": str(idx), "text": " ".join(words[idx:end]), "labels": " ".join(labels[idx:end]), }) From ff6ea37fa96e1058d1510738a87b728a44821427 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 16:32:00 -0700 Subject: [PATCH 09/15] fix for roberta --- pyhealth/models/transformer_deid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyhealth/models/transformer_deid.py b/pyhealth/models/transformer_deid.py index 6d14cfe38..631ec45eb 100644 --- a/pyhealth/models/transformer_deid.py +++ b/pyhealth/models/transformer_deid.py @@ -135,7 +135,11 @@ def __init__( self.label_vocab = LABEL_VOCAB self.num_labels = len(LABEL_VOCAB) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + # 
add_prefix_space=True is required for RoBERTa when using + # is_split_into_words=True in the forward pass. + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, add_prefix_space=True + ) self.encoder = AutoModel.from_pretrained( model_name, hidden_dropout_prob=dropout, From d80ad933438f6016c100ac1ed7f83e6108247fb1 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 20:20:45 -0700 Subject: [PATCH 10/15] clarify comments --- pyhealth/datasets/physionet_deid.py | 9 +++++++-- pyhealth/models/transformer_deid.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pyhealth/datasets/physionet_deid.py b/pyhealth/datasets/physionet_deid.py index 7e7599d52..d8e190977 100644 --- a/pyhealth/datasets/physionet_deid.py +++ b/pyhealth/datasets/physionet_deid.py @@ -11,10 +11,15 @@ Paper link: https://doi.org/10.1186/1472-6947-8-32 -BIO tagging and PHI classification adapted from the bert-deid reference -implementation by Johnson et al.: +PHI category mapping in classify_phi() inspired by the label groupings +in the bert-deid reference implementation by Johnson et al.: https://github.com/alistairewj/bert-deid/blob/master/bert_deid/label.py +Task paper: + Johnson, Alistair E.W., et al. "Deidentification of free-text medical + records using pre-trained bidirectional transformers." Proceedings of + the ACM Conference on Health, Inference, and Learning (CHIL), 2020. + Author: Matt McKenna (mtm16@illinois.edu) """ diff --git a/pyhealth/models/transformer_deid.py b/pyhealth/models/transformer_deid.py index 631ec45eb..e944ddcb6 100644 --- a/pyhealth/models/transformer_deid.py +++ b/pyhealth/models/transformer_deid.py @@ -17,8 +17,8 @@ TransformersModel (pyhealth/models/transformers_model.py), adapted for token-level classification instead of sequence-level. 
-Subword alignment logic adapted from the reference implementation: - https://github.com/alistairewj/bert-deid +Subword alignment follows the standard HuggingFace token +classification pattern (see BertForTokenClassification). Author: Matt McKenna (mtm16@illinois.edu) From b9da8a62b241b3c5b8ec20f1af704535d76738a7 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 20:31:34 -0700 Subject: [PATCH 11/15] Add tests for windowing --- tests/core/test_physionet_deid.py | 47 +++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/core/test_physionet_deid.py b/tests/core/test_physionet_deid.py index 12de4b7a6..6cfb47c6e 100644 --- a/tests/core/test_physionet_deid.py +++ b/tests/core/test_physionet_deid.py @@ -135,6 +135,53 @@ def test_task_labels_are_valid_bio(self): for label in sample["labels"].split(" "): self.assertIn(label, valid) + def test_task_sample_has_patient_id(self): + self.assertIn("patient_id", self.samples[0]) + + +class TestDeIDNERTaskWindowing(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "physionet_deid" + ) + cls.cache_dir = tempfile.TemporaryDirectory() + cls.dataset = PhysioNetDeIDDataset( + root=str(cls.root), cache_dir=cls.cache_dir.name + ) + cls.task = DeIDNERTask(window_size=10, window_overlap=5) + cls.samples = cls.dataset.set_task(cls.task) + + @classmethod + def tearDownClass(cls): + cls.samples.close() + cls.cache_dir.cleanup() + metadata = cls.root / "physionet_deid_metadata.csv" + if metadata.exists(): + metadata.unlink() + + def test_windowing_produces_more_samples(self): + """Windowing should produce more samples than the 4 notes.""" + self.assertGreater(len(self.samples), 4) + + def test_window_size_respected(self): + """Each window should have at most window_size tokens.""" + for sample in self.samples: + tokens = sample["text"].split(" ") + self.assertLessEqual(len(tokens), 10) + + def 
test_window_text_and_labels_same_length(self): + for sample in self.samples: + tokens = sample["text"].split(" ") + labels = sample["labels"].split(" ") + self.assertEqual(len(tokens), len(labels)) + + def test_window_has_patient_id(self): + self.assertIn("patient_id", self.samples[0]) + if __name__ == "__main__": unittest.main() From 62888388aa939b7cdd6ed88d54302d8e9acff3d0 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 22:56:49 -0700 Subject: [PATCH 12/15] add rst files --- docs/api/datasets.rst | 1 + ...pyhealth.datasets.PhysioNetDeIDDataset.rst | 9 ++++++++ docs/api/models.rst | 1 + .../pyhealth.models.TransformerDeID.rst | 9 ++++++++ docs/api/tasks.rst | 1 + docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst | 7 ++++++ .../physionet_deid_ner_transformer_deid.py | 23 +++++++++++++++++++ pyhealth/datasets/physionet_deid.py | 3 ++- 8 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst create mode 100644 docs/api/models/pyhealth.models.TransformerDeID.rst create mode 100644 docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..8d9a59d21 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.PhysioNetDeIDDataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst new file mode 100644 index 000000000..4e04cd629 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PhysioNetDeIDDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.PhysioNetDeIDDataset +======================================= + +The PhysioNet 
De-Identification dataset. For more information see `here `_. Access requires PhysioNet credentialing. + +.. autoclass:: pyhealth.datasets.PhysioNetDeIDDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..166402e86 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -177,6 +177,7 @@ API Reference models/pyhealth.models.GNN models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel + models/pyhealth.models.TransformerDeID models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet models/pyhealth.models.GraphCare diff --git a/docs/api/models/pyhealth.models.TransformerDeID.rst b/docs/api/models/pyhealth.models.TransformerDeID.rst new file mode 100644 index 000000000..d07aa94aa --- /dev/null +++ b/docs/api/models/pyhealth.models.TransformerDeID.rst @@ -0,0 +1,9 @@ +pyhealth.models.TransformerDeID +=================================== + +Transformer-based token classifier for clinical text de-identification. + +.. autoclass:: pyhealth.models.TransformerDeID + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..23a4e06e5 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -224,6 +224,7 @@ Available Tasks Sleep Staging v2 Benchmark EHRShot ChestX-ray14 Binary Classification + De-Identification NER ChestX-ray14 Multilabel Classification Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) diff --git a/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst new file mode 100644 index 000000000..2b7428f6e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DeIDNERTask.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DeIDNERTask +======================================= + +.. 
autoclass:: pyhealth.tasks.DeIDNERTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py index 059fd4972..b1a597944 100644 --- a/examples/physionet_deid_ner_transformer_deid.py +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -14,10 +14,33 @@ - Epochs: 3 - Weight decay: 0.01 +Ablation results (3 epochs, 80/10/10 patient split, seed=42): + + Config Precision Recall F1 + BERT, no window 95.1% 70.3% 80.8% + BERT, win=64/32 94.1% 69.0% 79.6% + BERT, win=100/60 86.9% 75.7% 80.9% + BERT, win=200/100 94.7% 69.4% 80.1% + RoBERTa, no window 98.1% 64.7% 78.0% + RoBERTa, win=100/60 82.6% 68.6% 75.0% + + BERT with window=100/60 achieves the best F1 (80.9%), matching the + paper's window configuration. Windowing improves recall by allowing + BERT to see tokens beyond the 512 truncation limit. RoBERTa has + higher precision but lower recall than BERT across configurations. + Usage: python examples/physionet_deid_ner_transformer_deid.py \ --data_root path/to/deidentifiedmedicaltext/1.0 + # With windowing (paper Section 3.3): + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/data --window_size 100 --window_overlap 60 + + # With RoBERTa: + python examples/physionet_deid_ner_transformer_deid.py \ + --data_root path/to/data --model_name roberta-base + Author: Matt McKenna (mtm16@illinois.edu) """ diff --git a/pyhealth/datasets/physionet_deid.py b/pyhealth/datasets/physionet_deid.py index d8e190977..5bd6acc95 100644 --- a/pyhealth/datasets/physionet_deid.py +++ b/pyhealth/datasets/physionet_deid.py @@ -328,7 +328,8 @@ def _index_data(self, root: str) -> pd.DataFrame: ): pid, nid = key orig = orig_records[key] - deid = deid_records.get(key, "") # defensive: missing key yields empty string + # Missing key yields empty string (no deid version). 
+ deid = deid_records.get(key, "") spans = phi_spans_in_original(orig, deid) tagged = bio_tag(orig, spans) From 238f98b15fd9ba3a1b4c63efa035be4e7c790805 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Sun, 12 Apr 2026 22:58:37 -0700 Subject: [PATCH 13/15] remove dead code --- examples/physionet_deid_ner_transformer_deid.py | 1 - pyhealth/tasks/deid_ner.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/examples/physionet_deid_ner_transformer_deid.py b/examples/physionet_deid_ner_transformer_deid.py index b1a597944..eaa421bd9 100644 --- a/examples/physionet_deid_ner_transformer_deid.py +++ b/examples/physionet_deid_ner_transformer_deid.py @@ -56,7 +56,6 @@ from pyhealth.datasets.splitter import split_by_patient from pyhealth.models.transformer_deid import ( IGNORE_INDEX, - LABEL_VOCAB, TransformerDeID, ) from pyhealth.tasks import DeIDNERTask diff --git a/pyhealth/tasks/deid_ner.py b/pyhealth/tasks/deid_ner.py index 9324a31b6..67b212bac 100644 --- a/pyhealth/tasks/deid_ner.py +++ b/pyhealth/tasks/deid_ner.py @@ -19,15 +19,12 @@ Matt McKenna (mtm16@illinois.edu) """ -import logging from typing import Dict, List, Optional, Type, Union from pyhealth.data import Event, Patient from pyhealth.processors.text_processor import TextProcessor from pyhealth.tasks import BaseTask -logger = logging.getLogger(__name__) - class DeIDNERTask(BaseTask): """Token-level NER task for clinical text de-identification. 
From ee1fb3a85ed7048f02fe7be008a2e5d764f8d33c Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Mon, 13 Apr 2026 21:03:53 -0700 Subject: [PATCH 14/15] add deidentify method so i can test it out manually --- pyhealth/models/transformer_deid.py | 39 +++++++++++++++++++++++++++++ tests/core/test_transformer_deid.py | 26 +++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/pyhealth/models/transformer_deid.py b/pyhealth/models/transformer_deid.py index e944ddcb6..964942f5a 100644 --- a/pyhealth/models/transformer_deid.py +++ b/pyhealth/models/transformer_deid.py @@ -222,3 +222,42 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: "y_prob": y_prob, "y_true": labels, } + + def deidentify(self, text: str, redact: str = "[REDACTED]") -> str: + """Replace PHI in a clinical note with a redaction marker. + + Args: + text: Raw clinical note as a string. + redact: Replacement string for PHI tokens. + + Returns: + The note with PHI tokens replaced. + + Example:: + >>> model.deidentify("Patient John Smith was seen") + 'Patient [REDACTED] [REDACTED] was seen' + """ + words = text.split() + # Forward pass with dummy labels (all O) since we only + # need predictions, not loss. + dummy_labels = " ".join(["O"] * len(words)) + self.eval() + with torch.no_grad(): + result = self(text=[text], labels=[dummy_labels]) + + preds = result["logit"][0].argmax(dim=-1) + y_true = result["y_true"][0] + + # Map predictions back to words using the non-ignored positions. 
+ word_idx = 0 + output = [] + for j in range(len(preds)): + if y_true[j].item() == IGNORE_INDEX: + continue + if preds[j].item() != 0: # non-O = PHI + output.append(redact) + else: + output.append(words[word_idx]) + word_idx += 1 + + return " ".join(output) diff --git a/tests/core/test_transformer_deid.py b/tests/core/test_transformer_deid.py index 727691a3b..417627e8f 100644 --- a/tests/core/test_transformer_deid.py +++ b/tests/core/test_transformer_deid.py @@ -174,5 +174,31 @@ def test_backward(self): dataset.close() +class TestTransformerDeIDDeidentify(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = TransformerDeID(dataset=cls.dataset) + cls.model.eval() + + @classmethod + def tearDownClass(cls): + cls.dataset.close() + + def test_returns_string(self): + result = self.model.deidentify("Patient John Smith was seen") + self.assertIsInstance(result, str) + + def test_same_word_count(self): + """Output should have same number of words (redacted or not).""" + text = "Patient John Smith was seen" + result = self.model.deidentify(text) + self.assertEqual(len(result.split()), len(text.split())) + + def test_custom_redact_marker(self): + result = self.model.deidentify("Patient John", redact="[PHI]") + self.assertNotIn("[REDACTED]", result) + + if __name__ == "__main__": unittest.main() From 92dcca5330b264a96e67ddbbd8cf17f662de1099 Mon Sep 17 00:00:00 2001 From: Matt McKenna Date: Mon, 13 Apr 2026 22:37:43 -0700 Subject: [PATCH 15/15] combine tests to make them faster --- tests/core/test_transformer_deid.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/tests/core/test_transformer_deid.py b/tests/core/test_transformer_deid.py index 417627e8f..e87a68c41 100644 --- a/tests/core/test_transformer_deid.py +++ b/tests/core/test_transformer_deid.py @@ -159,43 +159,32 @@ def test_y_prob_sums_to_one(self): def test_backward(self): """Loss backward should produce gradients.""" 
- dataset = _make_dataset() - model = TransformerDeID(dataset=dataset) - result = model( + # Need train mode and fresh forward pass for gradients. + self.model.train() + result = self.model( text=["Patient John Smith was seen"], labels=["O B-NAME I-NAME O O"], ) result["loss"].backward() has_grad = any( p.requires_grad and p.grad is not None - for p in model.parameters() + for p in self.model.parameters() ) self.assertTrue(has_grad) - dataset.close() + self.model.eval() + self.model.zero_grad() - -class TestTransformerDeIDDeidentify(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dataset = _make_dataset() - cls.model = TransformerDeID(dataset=cls.dataset) - cls.model.eval() - - @classmethod - def tearDownClass(cls): - cls.dataset.close() - - def test_returns_string(self): + def test_deidentify_returns_string(self): result = self.model.deidentify("Patient John Smith was seen") self.assertIsInstance(result, str) - def test_same_word_count(self): + def test_deidentify_same_word_count(self): """Output should have same number of words (redacted or not).""" text = "Patient John Smith was seen" result = self.model.deidentify(text) self.assertEqual(len(result.split()), len(text.split())) - def test_custom_redact_marker(self): + def test_deidentify_custom_redact_marker(self): result = self.model.deidentify("Patient John", redact="[PHI]") self.assertNotIn("[REDACTED]", result)