diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..ea772bfec 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -224,6 +224,7 @@ Available Datasets datasets/pyhealth.datasets.SampleDataset datasets/pyhealth.datasets.MIMIC3Dataset datasets/pyhealth.datasets.MIMIC4Dataset + datasets/pyhealth.datasets.MIMIC4FHIRDataset datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset datasets/pyhealth.datasets.eICUDataset diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst new file mode 100644 index 000000000..1a19fc5e9 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst @@ -0,0 +1,70 @@ +pyhealth.datasets.MIMIC4FHIRDataset +===================================== + +`MIMIC-IV on FHIR `_ NDJSON ingest +for CEHR-style token sequences used with +:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and +:class:`~pyhealth.models.EHRMambaCEHR`. + +YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the +earlier nested-object approach, the YAML now declares a normal ``tables:`` +schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``, +``observation``, ``medication_request``, ``procedure``). The class subclasses +:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars +``global_event_df`` backed by cached Parquet (``global_event_df.parquet/part-*.parquet``), +same tabular path as other datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`, +:meth:`iter_patients`, :meth:`get_patient`, etc. + +**Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read +**line by line**; each resource is normalized into a flattened per-resource +Parquet table under ``cache/flattened_tables/``. 
Those tables are then fed +through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to +materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual +table-first pipeline instead of reparsing nested JSON per patient downstream. + +**``max_patients``.** When set, the loader selects the first *N* patient ids after +a **sorted** ``unique`` over the flattened patient table, filters every +normalized table to that cohort, and then builds ``global_event_df`` from the +filtered tables. Ingest still scans all matching NDJSON once unless you also +override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards). + +**Downstream memory (still important).** Streaming ingest avoids loading the +entire NDJSON corpus into RAM at once, but other steps can still be heavy on +large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and +:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the +model and batches. For a **full** PhysioNet tree, plan for **large disk** +(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow +and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when +prototyping on a laptop. + +**Recommended hardware (informal)** + +Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller** +than “load everything into Python”; wall time still grows with **decompressed +NDJSON volume** and the amount of flattened table data produced. + +* **Smoke / CI** + Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``): + a recent laptop is sufficient. + +* **Laptop-scale real FHIR subset** + A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps + cache and task passes manageable. **≥ 16 GB** system RAM is a practical + comfort target for Polars + trainer + OS; validate GPU **VRAM** for your + ``max_len`` and batch size. 
+
+* **Full default globs on a complete export**
+  Favor **workstations or servers** with **fast SSD**, **large disk**, and
+  **ample RAM** for downstream steps—not because NDJSON is fully buffered in
+  memory during ingest, but because total work and caches still scale with the
+  full dataset.
+
+.. autoclass:: pyhealth.datasets.MIMIC4FHIRDataset
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autoclass:: pyhealth.datasets.ConceptVocab
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..c7e9f2729 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -185,6 +185,7 @@ API Reference
    models/pyhealth.models.MoleRec
    models/pyhealth.models.Deepr
    models/pyhealth.models.EHRMamba
+   models/pyhealth.models.EHRMambaCEHR
    models/pyhealth.models.JambaEHR
    models/pyhealth.models.ContraWR
    models/pyhealth.models.SparcNet
diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
new file mode 100644
index 000000000..79466cad3
--- /dev/null
+++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
@@ -0,0 +1,12 @@
+pyhealth.models.EHRMambaCEHR
+===================================
+
+EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`)
+and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.datasets.mimic4_fhir.MIMIC4FHIRDataset`.
+
+.. autoclass:: pyhealth.models.EHRMambaCEHR
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 399b8f1aa..83790ca44 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -214,6 +214,7 @@ Available Tasks
    Drug Recommendation
    Length of Stay Prediction
    Medical Transcriptions Classification
+   MPF Clinical Prediction (FHIR)
    Mortality Prediction (Next Visit)
    Mortality Prediction (StageNet MIMIC-IV)
    Patient Linkage (MIMIC-III)
diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
new file mode 100644
index 000000000..ff66deb08
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
@@ -0,0 +1,12 @@
+pyhealth.tasks.mpf_clinical_prediction
+======================================
+
+Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR
+token timelines, paired with :class:`~pyhealth.datasets.MIMIC4FHIRDataset` and
+:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the
+paper linked in the course replication PR.
+
+.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py
new file mode 100644
index 000000000..a80d9c43c
--- /dev/null
+++ b/examples/mimic4fhir_mpf_ehrmamba.py
@@ -0,0 +1,842 @@
+"""EHRMambaCEHR on MIMIC-IV FHIR NDJSON with MPF clinical prediction (ablations).
+
+Replication target: EHRMamba / CEHR-style modeling on tokenized FHIR timelines
+(e.g. `arXiv:2405.14567 <https://arxiv.org/abs/2405.14567>`_). This script is
+runnable end-to-end on **synthetic** NDJSON (``--quick-test``) or on
+credentialed MIMIC-IV on FHIR from PhysioNet.
+
+Experimental setup (for write-ups / PR):
+  * **Data**: Synthetic two-patient NDJSON (``--quick-test``) or disk NDJSON
+    under ``MIMIC4_FHIR_ROOT`` / ``--fhir-root``.
* **Task ablations**: ``max_len`` (context window), ``use_mpf`` vs generic
+    CLS/REG boundaries (``--no-mpf``).
+  * **Model ablations**: ``hidden_dim`` (embedding width); optional dropout
+    fixed at 0.1 in this script.
+  * **Train**: Adam via :class:`~pyhealth.trainer.Trainer`, monitor ROC-AUC,
+    report test ROC-AUC / PR-AUC.
+
+**Ablation mode** (``--ablation``): sweeps a small grid on synthetic data only,
+trains 1 epoch per config, and prints a comparison table. Use this to document
+how task/model knobs affect metrics on the minimal fixture before scaling to
+real FHIR.
+
+**Findings** (fill in after your runs; synthetic runs are noisy):
+    On ``--quick-test`` data, longer ``max_len`` and MPF specials typically
+    change logits enough to move AUC slightly; real MIMIC-IV FHIR runs are
+    needed for conclusive comparisons. Paste your table from ``--ablation``
+    into the PR description.
+
+**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into
+flattened per-resource Parquet tables (bounded RAM during ingest). This example
+trains via ``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed
+:class:`~pyhealth.datasets.sample_dataset.SampleDataset` →
+:class:`~pyhealth.trainer.Trainer` (PyHealth's standard path), instead of
+materializing all samples with ``gather_samples()``. Prefer ``--max-patients``
+to bound ingest when possible. Very large cohorts still need RAM/disk for task
+caches and MPF vocabulary warmup.
+
+**Offline flattened tables (NDJSON normalization already done):** pass
+``--prebuilt-global-event-dir`` pointing at a directory containing the
+normalized FHIR tables (``patient.parquet``, ``encounter.parquet``,
+``condition.parquet``, etc.).
The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID,
+then lets :class:`~pyhealth.datasets.BaseDataset` rebuild
+``global_event_df.parquet/`` from those tables — the downstream path is still
+``global_event_df`` → :class:`~pyhealth.data.Patient` →
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` →
+:class:`~pyhealth.trainer.Trainer`. Use ``--fhir-root`` / ``--glob-pattern`` /
+``--max-patients -1`` matching the ingest fingerprint.
+``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using
+a label-aware deterministic patient subset. The full ``unique_patient_ids``
+scan and MPF vocabulary warmup in the dataset still walk the cached cohort.
+
+**Approximate minimum specs** (``--quick-test``, CPU, synthetic 2-patient
+fixture; measured once on macOS/arm64 with ``/usr/bin/time -l``): peak RSS
+~**600–700 MiB**, wall **~10–15 s** for two short epochs. Real NDJSON/GZ at
+scale needs proportionally more RAM, disk, and time; a GPU accelerates
+training, not the streaming NDJSON parse.
+
+Usage:
+    cd PyHealth && PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test
+    PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test --ablation
+    export MIMIC4_FHIR_ROOT=/path/to/fhir
+    pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT"
+
+    # Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train
+    pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\
+        --prebuilt-global-event-dir /path/to/flattened_table_dir \\
+        --fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\
+        --train-patient-cap 2048 --epochs 2 \\
+        --ntfy-url 'https://ntfy.sh/your-topic'
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import urllib.error
+import urllib.request
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+_parser = argparse.ArgumentParser(description="EHRMambaCEHR on MIMIC-IV FHIR (MPF)")
+_parser.add_argument(
+    "--gpu",
+    type=int,
+    default=None,
+    help="GPU index; sets CUDA_VISIBLE_DEVICES.",
+)
+_parser.add_argument(
+    "--fhir-root",
+    type=str,
+    default=None,
+    help="Root directory with NDJSON (default: MIMIC4_FHIR_ROOT env).",
+)
+_parser.add_argument(
+    "--glob-pattern",
+    type=str,
+    default=None,
+    help=(
+        "Override glob for NDJSON/NDJSON.GZ (default: yaml **/*.ndjson.gz). "
+        "Use a narrow pattern to limit ingest time and cache size."
+    ),
+)
+_parser.add_argument(
+    "--max-len",
+    type=int,
+    default=512,
+    help="Sequence length ablation (e.g.
512 / 1024 / 2048 per proposal).", +) +_parser.add_argument( + "--no-mpf", + action="store_true", + help="Ablation: use generic CLS/REG specials instead of task MPF tokens.", +) +_parser.add_argument( + "--hidden-dim", + type=int, + default=128, + help="Embedding / hidden size (model ablation).", +) +_parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Adam learning rate (trainer.train optimizer_params).", +) +_parser.add_argument( + "--quick-test", + action="store_true", + help="Use synthetic in-memory FHIR lines only (no disk root).", +) +_parser.add_argument( + "--ablation", + action="store_true", + help="Run a small max_len × MPF × hidden_dim grid on synthetic data; print table.", +) +_parser.add_argument( + "--epochs", + type=int, + default=None, + help="Training epochs (default: 2 with --quick-test, else 20).", +) +_parser.add_argument( + "--max-patients", + type=int, + default=500, + help=( + "Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, " + "match an uncapped NDJSON→flattened-table export)." + ), +) +_parser.add_argument( + "--prebuilt-global-event-dir", + type=str, + default=None, + help=( + "Directory with normalized flattened FHIR tables (*.parquet). Seeds " + "cache/flattened_tables/ so training skips NDJSON normalization " + "(downstream unchanged: Patient + MPF + Trainer)." + ), +) +_parser.add_argument( + "--ingest-num-shards", + type=int, + default=None, + help="Compatibility no-op: retained for CLI stability with older runs.", +) +_parser.add_argument( + "--train-patient-cap", + type=int, + default=None, + help=( + "After cache is ready, only build samples from a deterministic label-aware " + "patient subset of size N (reduces train time; unique-id scan of " + "global_event_df still runs once)." + ), +) +_parser.add_argument( + "--ntfy-url", + type=str, + default=None, + help="POST notification when main() finishes (e.g. 
https://ntfy.sh/topic).", +) +_parser.add_argument( + "--loss-plot-path", + type=str, + default=None, + help="Write loss curve PNG here (default: alongside Trainer log under output/).", +) +_parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="PyHealth dataset cache parent (UUID subdir added by MIMIC4FHIRDataset).", +) +_parser.add_argument( + "--task-num-workers", + type=int, + default=None, + help=( + "Workers for LitData task/processor transforms (default: dataset " + "``num_workers``, usually 1)." + ), +) +_pre_args, _ = _parser.parse_known_args() +if _pre_args.gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(_pre_args.gpu) + +import torch + +import polars as pl + +from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader +from pyhealth.datasets.fhir_cehr import infer_mortality_label +from pyhealth.models import EHRMambaCEHR +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask +from pyhealth.trainer import Trainer + + +class PatientCappedMPFTask(MPFClinicalPredictionTask): + """Example-only: limit task transform to an explicit patient_id allow-list.""" + + def __init__( + self, + *, + max_len: int, + use_mpf: bool, + patient_ids_allow: List[str], + ) -> None: + super().__init__(max_len=max_len, use_mpf=use_mpf) + self.patient_ids_allow = patient_ids_allow + + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + return df.filter(pl.col("patient_id").is_in(self.patient_ids_allow)) + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +SEED = 42 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +BATCH_SIZE = 8 +EPOCHS = 20 +SPLIT_RATIOS = (0.7, 0.1, 0.2) + + +def _max_patients_arg(v: int) -> Optional[int]: + return None if v is not None and v < 0 else v + + +def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None: + """Copy normalized per-resource parquet tables into the dataset cache.""" + + tables = sorted(prebuilt_dir.glob("*.parquet")) + if not tables: 
+ raise FileNotFoundError( + f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables." + ) + prepared = ds.prepared_tables_dir + if prepared.exists() and any(prepared.glob("*.parquet")): + return + prepared.mkdir(parents=True, exist_ok=True) + for src in tables: + dest = prepared / src.name + if dest.exists(): + continue + try: + os.link(src, dest) + except OSError: + shutil.copy2(src, dest) + + +def _parse_train_losses_from_log(log_path: Path) -> List[float]: + """Mean training loss per epoch from Trainer file log.""" + + if not log_path.is_file(): + return [] + text = log_path.read_text(encoding="utf-8", errors="replace") + losses: List[float] = [] + lines = text.splitlines() + for i, line in enumerate(lines): + if "--- Train epoch-" in line and i + 1 < len(lines): + m = re.search(r"loss:\s*([0-9.eE+-]+)", lines[i + 1]) + if m: + losses.append(float(m.group(1))) + return losses + + +def _write_loss_plot(losses: List[float], out_path: Path) -> None: + if not losses: + return + out_path.parent.mkdir(parents=True, exist_ok=True) + try: + import matplotlib.pyplot as plt + except ImportError: + csv_path = out_path.with_suffix(".csv") + csv_path.write_text( + "epoch,train_loss_mean\n" + + "\n".join(f"{i},{v}" for i, v in enumerate(losses)), + encoding="utf-8", + ) + print( + "matplotlib not installed; wrote", csv_path, "(pip install matplotlib for PNG)" + ) + return + plt.figure(figsize=(6, 3.5)) + plt.plot(range(len(losses)), losses, marker="o", linewidth=1) + plt.xlabel("epoch") + plt.ylabel("mean train loss") + plt.title("EHRMambaCEHR training loss (MPF)") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(out_path, dpi=120) + plt.close() + print("loss plot:", out_path) + + +def _ntfy(url: str, title: str, message: str) -> None: + try: + req = urllib.request.Request( + url, + data=message.encode("utf-8"), + method="POST", + ) + req.add_header("Title", title[:200]) + with urllib.request.urlopen(req, timeout=60) as resp: + if 
resp.status >= 400: + print("ntfy HTTP", resp.status, file=sys.stderr) + except urllib.error.URLError as e: + print("ntfy failed:", e, file=sys.stderr) + + +def _quick_test_ndjson_dir() -> str: + """Write two-patient synthetic NDJSON; returns temp directory (caller cleans up).""" + + from pyhealth.datasets.fhir_ingest import synthetic_mpf_two_patient_ndjson_text + + tmp = tempfile.mkdtemp(prefix="pyhealth_mimic4_fhir_quick_") + Path(tmp, "fixture.ndjson").write_text( + synthetic_mpf_two_patient_ndjson_text(), + encoding="utf-8", + ) + return tmp + + +def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int: + patient = ds.get_patient(patient_id) + return int(infer_mortality_label(patient)) + + +def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None: + found: Dict[int, str] = {} + scanned = 0 + for patient_id in ds.unique_patient_ids: + label = _patient_label(ds, patient_id) + scanned += 1 + found.setdefault(label, patient_id) + if len(found) == 2: + print( + "label preflight:", + {"scanned_patients": scanned, "example_patient_ids": found}, + ) + return + raise SystemExit( + "Binary mortality example found only one label in the available cohort; " + "cannot build a valid binary training set from this cache." + ) + + +def _select_patient_ids_for_cap( + ds: MIMIC4FHIRDataset, requested_cap: int +) -> List[str]: + patient_ids = ds.unique_patient_ids + if not patient_ids: + return [] + + desired = max(2, requested_cap) + desired = min(desired, len(patient_ids)) + if desired < requested_cap: + print( + f"train_patient_cap requested {requested_cap}, but only {desired} patients are available." + ) + elif requested_cap < 2: + print( + f"train_patient_cap={requested_cap} is too small for binary labels; using {desired}." 
+ ) + + encountered: List[str] = [] + label_by_patient_id: Dict[str, int] = {} + first_by_label: Dict[int, str] = {} + for patient_id in patient_ids: + label = _patient_label(ds, patient_id) + encountered.append(patient_id) + label_by_patient_id[patient_id] = label + first_by_label.setdefault(label, patient_id) + if len(encountered) >= desired and len(first_by_label) == 2: + break + + if len(first_by_label) < 2: + raise SystemExit( + "Unable to satisfy --train-patient-cap with both binary labels from the " + "available cohort. Use a different cache/export or remove the cap." + ) + + selected = encountered[:desired] + selected_labels = {label_by_patient_id[pid] for pid in selected} + if len(selected_labels) == 1: + missing_label = 1 - next(iter(selected_labels)) + replacement = first_by_label[missing_label] + for idx in range(len(selected) - 1, -1, -1): + if label_by_patient_id[selected[idx]] != missing_label: + selected[idx] = replacement + break + + counts = { + 0: sum(1 for pid in selected if label_by_patient_id[pid] == 0), + 1: sum(1 for pid in selected if label_by_patient_id[pid] == 1), + } + print( + "train_patient_cap selection:", + { + "requested": requested_cap, + "selected": len(selected), + "scanned_patients": len(encountered), + "label_counts": counts, + }, + ) + return selected + + +def _sample_label(sample: Dict[str, Any]) -> int: + label = sample["label"] + if isinstance(label, torch.Tensor): + return int(label.reshape(-1)[0].item()) + return int(label) + + +def _split_counts(n: int) -> List[int]: + if n < 3: + raise ValueError("Need at least 3 samples for three-way stratified split.") + counts = [1, 1, 1] + remaining = n - 3 + raw = [ratio * remaining for ratio in SPLIT_RATIOS] + floors = [int(x) for x in raw] + for i, floor in enumerate(floors): + counts[i] += floor + assigned = sum(counts) + order = sorted( + range(3), + key=lambda i: raw[i] - floors[i], + reverse=True, + ) + for i in order: + if assigned >= n: + break + counts[i] += 1 + assigned 
+= 1 + counts[0] += n - assigned + return counts + + +def _split_sample_dataset_for_binary_metrics(sample_ds: Any) -> tuple[Any, Any, Any]: + if len(sample_ds) < 8: + print("sample count < 8; reusing the full dataset for train/val/test.") + return sample_ds, sample_ds, sample_ds + + label_to_indices: Dict[int, List[int]] = {0: [], 1: []} + for idx in range(len(sample_ds)): + label_to_indices[_sample_label(sample_ds[idx])].append(idx) + + label_counts = {label: len(indices) for label, indices in label_to_indices.items()} + min_count = min(label_counts.values()) + if min_count < 3: + print( + "label distribution too small for disjoint binary train/val/test splits; " + "reusing the full dataset for train/val/test.", + label_counts, + ) + return sample_ds, sample_ds, sample_ds + + rng = random.Random(SEED) + split_indices: List[List[int]] = [[], [], []] + for indices in label_to_indices.values(): + shuffled = indices[:] + rng.shuffle(shuffled) + n_train, n_val, n_test = _split_counts(len(shuffled)) + split_indices[0].extend(shuffled[:n_train]) + split_indices[1].extend(shuffled[n_train : n_train + n_val]) + split_indices[2].extend(shuffled[n_train + n_val : n_train + n_val + n_test]) + + for indices in split_indices: + indices.sort() + + split_counts = [] + for indices in split_indices: + split_counts.append( + { + 0: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 0), + 1: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 1), + "n": len(indices), + } + ) + print( + "binary stratified split counts:", + {"train": split_counts[0], "val": split_counts[1], "test": split_counts[2]}, + ) + return ( + sample_ds.subset(split_indices[0]), + sample_ds.subset(split_indices[1]), + sample_ds.subset(split_indices[2]), + ) + + +def _build_loaders_from_sample_dataset( + sample_ds: Any, + vocab_size: int, +) -> tuple[Any, Any, Any, Any, int]: + train_ds, val_ds, test_ds = _split_sample_dataset_for_binary_metrics(sample_ds) + train_loader = 
get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + return sample_ds, train_loader, val_loader, test_loader, vocab_size + + +def run_single_train( + *, + fhir_root: str, + max_len: int, + use_mpf: bool, + hidden_dim: int, + epochs: int, + lr: float = 1e-3, + glob_pattern: str = "*.ndjson", + cache_dir: Optional[str] = None, + dataset_max_patients: Optional[int] = 500, + ingest_num_shards: Optional[int] = None, + prebuilt_global_event_dir: Optional[str] = None, + train_patient_cap: Optional[int] = None, +) -> Dict[str, float]: + """Train/eval one configuration; returns test metrics (floats).""" + + ds_kw: Dict[str, Any] = { + "root": fhir_root, + "glob_pattern": glob_pattern, + "cache_dir": cache_dir, + "max_patients": dataset_max_patients, + } + if ingest_num_shards is not None: + ds_kw["ingest_num_shards"] = ingest_num_shards + ds = MIMIC4FHIRDataset(**ds_kw) + if prebuilt_global_event_dir: + _seed_flattened_table_cache( + Path(prebuilt_global_event_dir).expanduser().resolve(), ds + ) + if train_patient_cap is not None: + allow = _select_patient_ids_for_cap(ds, train_patient_cap) + task: MPFClinicalPredictionTask = PatientCappedMPFTask( + max_len=max_len, + use_mpf=use_mpf, + patient_ids_allow=allow, + ) + else: + _ensure_binary_label_coverage(ds) + task = MPFClinicalPredictionTask(max_len=max_len, use_mpf=use_mpf) + sample_ds = ds.set_task(task, num_workers=1) + vocab_size = ds.vocab.vocab_size + sample_ds, train_l, val_l, test_l, vocab_size = _build_loaders_from_sample_dataset( + sample_ds, vocab_size + ) + model = EHRMambaCEHR( + dataset=sample_ds, + vocab_size=vocab_size, + embedding_dim=hidden_dim, + num_layers=2, + dropout=0.1, + ) + trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE) + trainer.train( + train_dataloader=train_l, + val_dataloader=val_l, + epochs=epochs, + 
monitor="roc_auc",
+        optimizer_params={"lr": lr},
+    )
+    results = trainer.evaluate(test_l)
+    return {k: float(v) for k, v in results.items()}
+
+
+def run_ablation_table(*, lr: float = 1e-3) -> None:
+    """Task × model grid on synthetic NDJSON (short runs for comparison)."""
+
+    # Ablations: context length, MPF vs CLS/REG, plus one hidden_dim pair.
+    grid = [
+        (32, True, 32),
+        (32, False, 32),
+        (96, True, 64),
+        (96, False, 64),
+    ]
+    tmp = _quick_test_ndjson_dir()
+    try:
+        print(
+            "Ablation (synthetic, 1 epoch each): max_len, use_mpf, hidden_dim, lr="
+            f"{lr} -> test roc_auc, pr_auc"
+        )
+        rows = []
+        t0 = time.perf_counter()
+        for max_len, use_mpf, hidden_dim in grid:
+            metrics = run_single_train(
+                fhir_root=tmp,
+                max_len=max_len,
+                use_mpf=use_mpf,
+                hidden_dim=hidden_dim,
+                epochs=1,
+                lr=lr,
+                cache_dir=tmp,
+                dataset_max_patients=500,
+            )
+            rows.append((max_len, use_mpf, hidden_dim, metrics))
+            print(
+                f"  max_len={max_len} mpf={use_mpf} hid={hidden_dim} -> "
+                f"roc_auc={metrics['roc_auc']:.4f} pr_auc={metrics['pr_auc']:.4f}"
+            )
+        print("ablation_wall_s:", round(time.perf_counter() - t0, 2))
+        best = max(rows, key=lambda r: r[3]["roc_auc"])
+        print(
+            "best_by_roc_auc:",
+            {
+                "max_len": best[0],
+                "use_mpf": best[1],
+                "hidden_dim": best[2],
+                "metrics": best[3],
+            },
+        )
+    except Exception:
+        print(f"ablation: leaving scratch directory for debugging: {tmp}", file=sys.stderr)
+        raise
+    else:
+        shutil.rmtree(tmp, ignore_errors=True)
+
+
+def main() -> None:
+    args = _parser.parse_args()
+    status = "abort"
+    ntfy_detail = ""
+    try:
+        _main_train(args)
+        status = "ok"
+        ntfy_detail = "Training finished successfully."
+ except SystemExit as e: + status = "exit" + ntfy_detail = f"SystemExit {e.code!r}" + raise + except Exception as e: + status = "error" + ntfy_detail = f"{type(e).__name__}: {e}"[:3800] + raise + finally: + if args.ntfy_url and status in ("ok", "error"): + _ntfy( + args.ntfy_url, + "mimic-fhir-train OK" if status == "ok" else "mimic-fhir-train FAIL", + ntfy_detail, + ) + + +def _main_train(args: argparse.Namespace) -> None: + fhir_root = args.fhir_root or os.environ.get("MIMIC4_FHIR_ROOT") + quick = args.quick_test + quick_test_tmp: Optional[str] = None + if args.epochs is not None: + epochs = args.epochs + else: + epochs = 2 if quick else EPOCHS + + if args.ablation: + if not quick: + raise SystemExit("--ablation requires --quick-test (synthetic data only).") + run_ablation_table(lr=args.lr) + return + + print("EHRMambaCEHR – MIMIC-IV FHIR (MPF clinical prediction)") + print("device:", DEVICE) + print("max_len:", args.max_len, "| use_mpf:", not args.no_mpf) + print("hidden_dim:", args.hidden_dim, "| lr:", args.lr) + + sample_ds: Any + vocab: Any + + if quick: + quick_test_tmp = _quick_test_ndjson_dir() + ds = MIMIC4FHIRDataset( + root=quick_test_tmp, + glob_pattern="*.ndjson", + cache_dir=quick_test_tmp, + max_patients=500, + ) + try: + print( + "pipeline: synthetic NDJSON → flattened tables → global_event_df " + "→ set_task → SampleDataset → Trainer" + ) + task = MPFClinicalPredictionTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + ) + print("set_task (quick-test, num_workers=1)...") + t_task0 = time.perf_counter() + sample_ds = ds.set_task(task, num_workers=1) + print( + "set_task done: n_samples=", + len(sample_ds), + "wall_s=", + round(time.perf_counter() - t_task0, 2), + ) + vocab = ds.vocab + except Exception: + print( + f"quick-test: leaving NDJSON/Parquet scratch at {quick_test_tmp}", + file=sys.stderr, + ) + raise + else: + mp = _max_patients_arg(args.max_patients) + if not fhir_root or not os.path.isdir(fhir_root): + raise SystemExit( + "Set 
MIMIC4_FHIR_ROOT or pass --fhir-root to an existing directory " + "(NDJSON tree for ingest fingerprint, even when using --prebuilt-global-event-dir)." + ) + ds_kw: Dict[str, Any] = { + "root": fhir_root, + "max_patients": mp, + "cache_dir": args.cache_dir, + } + if args.glob_pattern is not None: + ds_kw["glob_pattern"] = args.glob_pattern + if args.ingest_num_shards is not None: + ds_kw["ingest_num_shards"] = args.ingest_num_shards + ds = MIMIC4FHIRDataset(**ds_kw) + if args.prebuilt_global_event_dir: + pb = Path(args.prebuilt_global_event_dir).expanduser().resolve() + if not pb.is_dir(): + raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}") + print( + "pipeline: offline flattened FHIR tables → seed flattened table cache " + "→ global_event_df → set_task → SampleDataset → Trainer " + "(no NDJSON normalization)" + ) + _seed_flattened_table_cache(pb, ds) + else: + print( + "pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df " + "→ set_task → SampleDataset → Trainer" + ) + print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp) + if args.train_patient_cap is not None: + print("train_patient_cap:", args.train_patient_cap) + allow = _select_patient_ids_for_cap(ds, args.train_patient_cap) + mpf_task: MPFClinicalPredictionTask = PatientCappedMPFTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + patient_ids_allow=allow, + ) + print("task patient allow-list size:", len(allow)) + else: + _ensure_binary_label_coverage(ds) + mpf_task = MPFClinicalPredictionTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + ) + nw = args.task_num_workers + if nw is None: + nw = ds.num_workers + print(f"set_task (LitData task cache, num_workers={nw})...") + t_task0 = time.perf_counter() + sample_ds = ds.set_task(mpf_task, num_workers=nw) + print( + "set_task done: n_samples=", + len(sample_ds), + "wall_s=", + round(time.perf_counter() - t_task0, 2), + ) + vocab = ds.vocab + print("fhir_root:", fhir_root) + + try: + if 
len(sample_ds) == 0: + raise SystemExit( + "No training samples (0 patients or empty sequences). " + "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in " + "pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, " + "construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'." + ) + + sample_ds, train_loader, val_loader, test_loader, vocab_size = ( + _build_loaders_from_sample_dataset(sample_ds, vocab.vocab_size) + ) + + model = EHRMambaCEHR( + dataset=sample_ds, + vocab_size=vocab_size, + embedding_dim=args.hidden_dim, + num_layers=2, + dropout=0.1, + ) + trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE) + + t0 = time.perf_counter() + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="roc_auc", + optimizer_params={"lr": args.lr}, + ) + results = trainer.evaluate(test_loader) + print("Test:", {k: float(v) for k, v in results.items()}) + print("wall_s:", round(time.perf_counter() - t0, 1)) + print("concept_vocab_size:", vocab.vocab_size) + + log_txt = ( + Path(trainer.exp_path) / "log.txt" if trainer.exp_path else None + ) + if log_txt and log_txt.is_file(): + losses = _parse_train_losses_from_log(log_txt) + print("train_loss_per_epoch:", losses) + plot_path = ( + Path(args.loss_plot_path) + if args.loss_plot_path + else Path(trainer.exp_path) / "train_loss.png" + ) + if trainer.exp_path: + _write_loss_plot(losses, plot_path) + finally: + if quick_test_tmp is not None: + shutil.rmtree(quick_test_tmp, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/pixi.lock b/pixi.lock index 0f11d28d7..26c1b2237 100644 --- a/pixi.lock +++ b/pixi.lock @@ -96,6 +96,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -217,6 +218,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -328,6 +330,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -440,6 +443,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -597,6 +601,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -727,6 +732,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -848,6 +854,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: 
https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -970,6 +977,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1158,6 +1166,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1324,6 +1333,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1474,6 +1484,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1626,6 +1637,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: 
https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1777,6 +1789,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1906,6 +1919,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2026,6 +2040,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2147,6 +2162,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2224,6 +2240,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2240,6 +2257,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/75/b4/b96bb66f6f8cc4669de44a158099b249c8159231d254ab6b092909388be5/fonttools-4.59.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl @@ -2269,6 +2287,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl @@ -2286,6 +2305,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -2308,6 +2328,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/34/43/3f250ec28edff1c06ffaa25faddbe13ae85c11a9724894cbdcf89427de78/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl @@ -2360,6 +2381,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/readline-8.2-h8382b9d_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/tk-8.6.13-noxft_h5688188_102.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2376,6 +2398,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/57/7969af50b26408be12baa317c6147588db5b38af2759e6df94554dbc5fdb/fonttools-4.59.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl @@ -2405,9 +2428,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - 
pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -2430,6 +2455,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/ff/5f/907a48c5f9b83302b4530605df1325963977fdf06753d3d8610d16c40197/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_aarch64.whl - pypi: https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl @@ -2472,6 +2498,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.2-h1d1bf99_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2488,6 +2515,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f3/bb/390990e7c457d377b00890d9f96a3ca13ae2517efafb6609c1756e213ba4/fonttools-4.59.0-cp313-cp313-macosx_10_13_universal2.whl @@ -2517,9 +2545,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: 
https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -2542,6 +2572,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/3b/0b/6ab0cc692b2890f4f7c74f6ffd4bba748dcb9312d5a7bd2328cb82204da1/rdkit-2025.3.3-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl @@ -2585,6 +2616,7 @@ environments: - conda: 
https://conda.anaconda.org/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/vc-14.3-h41ae7f8_26.conda - conda: https://conda.anaconda.org/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2602,6 +2634,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a0/ee/f626cd372932d828508137a79b85167fdcf3adab2e3bed433f295c596c6a/fonttools-4.59.0-cp313-cp313-win_amd64.whl @@ -2630,9 +2663,11 @@ environments: - 
pypi: https://files.pythonhosted.org/packages/74/07/1ed8277f8653c40ebc65985180b007879f6a836c525b3885dcc6448ae6cb/msgpack-1.1.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -2655,6 +2690,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/98/da/164e31b607c0cf22f1179cd15fa058780f940b21ec42ba3c9026c21897e3/rdkit-2025.3.3-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl 
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl @@ -2776,6 +2812,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ 
-2897,6 +2934,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -3008,6 +3046,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: 
https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -3120,6 +3159,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -3213,6 +3253,11 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 +- pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + name: absl-py + version: 2.4.0 + sha256: 88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d + requires_python: '>=3.10' - pypi: 
https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl name: accelerate version: 1.10.0 @@ -3958,6 +4003,11 @@ packages: - pkg:pypi/editables?source=hash-mapping size: 10828 timestamp: 1733208220327 +- pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz + name: editdistance + version: 0.8.1 + sha256: d1cdf80a5d5014b0c9126a69a42ce55a457b457f6986ff69ca98e4fe4d2d8fed + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl name: einops version: 0.8.2 @@ -5913,6 +5963,32 @@ packages: - pkg:pypi/nh3?source=hash-mapping size: 584955 timestamp: 1756737407424 +- pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl + name: nltk + version: 3.9.4 + sha256: f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f + requires_dist: + - click + - joblib + - regex>=2021.8.3 + - tqdm + - numpy ; extra == 'machine-learning' + - python-crfsuite ; extra == 'machine-learning' + - scikit-learn ; extra == 'machine-learning' + - scipy ; extra == 'machine-learning' + - matplotlib ; extra == 'plot' + - pyparsing ; extra == 'tgrep' + - twython ; extra == 'twitter' + - requests ; extra == 'corenlp' + - scipy ; extra == 'all' + - python-crfsuite ; extra == 'all' + - pyparsing ; extra == 'all' + - requests ; extra == 'all' + - numpy ; extra == 'all' + - scikit-learn ; extra == 'all' + - twython ; extra == 'all' + - matplotlib ; extra == 'all' + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: numpy version: 2.2.6 @@ -6100,6 +6176,26 @@ packages: 
purls: [] size: 9327033 timestamp: 1751392489008 +- pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl + name: orjson + version: 3.11.7 + sha256: b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl + name: orjson + version: 3.11.7 + sha256: 814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl + name: orjson + version: 3.11.7 + sha256: 1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + name: orjson + version: 3.11.7 + sha256: a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl name: outdated version: 0.2.2 @@ -7030,7 +7126,7 @@ packages: - pypi: ./ name: pyhealth version: 2.0.0 - sha256: f07719f9dceb759c35507216c8033d2f915d241418d4fad2ab51b37c0e73260f + sha256: a11efbf6fe99193820e0ee48e843cfb89526030b1d5e48ddc71c8fa543242572 requires_dist: - torch~=2.7.1 - torchvision @@ -7055,6 +7151,11 @@ packages: - more-itertools~=10.8.0 - einops>=0.8.0 - linear-attention-transformer>=0.19.1 + - orjson~=3.10 + - torch-geometric>=2.6.0 ; extra == 'graph' + - editdistance~=0.8.1 ; extra == 
'nlp' + - rouge-score~=0.1.2 ; extra == 'nlp' + - nltk~=3.9.1 ; extra == 'nlp' requires_python: '>=3.12,<3.14' - pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl name: pyparsing @@ -7416,6 +7517,16 @@ packages: - pkg:pypi/rich?source=compressed-mapping size: 201098 timestamp: 1753436991345 +- pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz + name: rouge-score + version: 0.1.2 + sha256: c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04 + requires_dist: + - absl-py + - nltk + - numpy + - six>=1.14.0 + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl name: s3transfer version: 0.16.0 diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..71b0d5b67 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,6 +59,8 @@ def __init__(self, *args, **kwargs): from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset +from .fhir_cehr import ConceptVocab +from .mimic4_fhir import MIMIC4FHIRDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset diff --git a/pyhealth/datasets/configs/mimic4_fhir.yaml b/pyhealth/datasets/configs/mimic4_fhir.yaml new file mode 100644 index 000000000..989ab79cf --- /dev/null +++ b/pyhealth/datasets/configs/mimic4_fhir.yaml @@ -0,0 +1,118 @@ +# MIMIC-IV FHIR Resource Flattening Configuration +# ================================================ +# +# This YAML defines the normalized schema for MIMIC-IV FHIR exports after streaming +# 
ingestion. Raw NDJSON/NDJSON.GZ resources are parsed and flattened into six +# per-resource-type Parquet tables (see ``tables`` below), then loaded through the +# standard BaseDataset pipeline for task construction. +# +# For ingest details, see the docstring of ``stream_fhir_ndjson_to_flat_tables()`` +# in ``pyhealth/datasets/mimic4_fhir.py``. + +version: "fhir_r4_flattened" + +# Glob Pattern(s) for NDJSON File Discovery +# =========================================== +# +# ``glob_patterns`` (list) or ``glob_pattern`` (string): Patterns to match NDJSON files +# under the ingest root directory. Patterns are applied via pathlib.Path.glob(). +# +# Default: Six targeted patterns matching PhysioNet MIMIC-IV FHIR Mimic* shard families +# that map to flattened tables. This avoids decompressing and parsing ~10% of PhysioNet +# exports (MedicationAdministration, Specimen, Organization, …) that are skipped by +# the flattener. +# +# Alternatives: +# - For non-PhysioNet naming, use a single broad pattern: +# glob_pattern: "**/*.ndjson.gz" +# - To test on a subset, use a narrower list: +# glob_patterns: +# - "**/MimicPatient*.ndjson.gz" +# - "**/MimicObservation*.ndjson.gz" +# +# Notes: +# - Patterns use ``**/`` for recursive search (works in both flat and nested layouts). +# - Can be overridden at runtime via MIMIC4FHIRDataset(glob_pattern=...) or +# MIMIC4FHIRDataset(glob_patterns=[...]). + +glob_patterns: + - "**/MimicPatient*.ndjson.gz" + - "**/MimicEncounter*.ndjson.gz" + - "**/MimicCondition*.ndjson.gz" + - "**/MimicObservation*.ndjson.gz" + - "**/MimicMedicationRequest*.ndjson.gz" + - "**/MimicProcedure*.ndjson.gz" + +# Flattened Table Schema +# ====================== +# +# Each table is normalized from a single FHIR resource type. Columns are: +# - patient_id (str): Foreign key to patient (derived from subject.reference or id). +# - [timestamp] (str): ISO 8601 datetime string (coerced; nullable). +# - attributes (List[str]): Additional columns from the resource.
+# +# Unsupported resource types (Medication, MedicationAdministration, Specimen, …) +# are silently dropped during ingest; only tables listed here are written. + +tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: + - "patient_fhir_id" + - "birth_date" + - "gender" + - "deceased_boolean" + - "deceased_datetime" + + encounter: + file_path: "encounter.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "encounter_class" + - "encounter_end" + + condition: + file_path: "condition.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + observation: + file_path: "observation.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + medication_request: + file_path: "medication_request.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + procedure: + file_path: "procedure.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" diff --git a/pyhealth/datasets/fhir_cehr.py b/pyhealth/datasets/fhir_cehr.py new file mode 100644 index 000000000..67ae35c5e --- /dev/null +++ b/pyhealth/datasets/fhir_cehr.py @@ -0,0 +1,364 @@ +"""CEHR-style tokenization, vocabulary, and sequence building for FHIR timelines. + +Key public API +-------------- +ConceptVocab + Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serializable. + +build_cehr_sequences(patient, vocab, max_len) + Flatten a Patient's tabular FHIR rows into CEHR-aligned feature lists. + +infer_mortality_label(patient) + Heuristic binary mortality label from flattened patient rows. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import orjson + +from pyhealth.data import Patient + +from .fhir_ingest import as_naive, parse_dt + +DEFAULT_PAD = 0 +DEFAULT_UNK = 1 + +__all__ = [ + # Constants + "DEFAULT_PAD", + "DEFAULT_UNK", + "EVENT_TYPE_TO_TOKEN_TYPE", + # Vocabulary + "ConceptVocab", + "ensure_special_tokens", + # Sequence building + "collect_cehr_timeline_events", + "warm_mpf_vocab_from_patient", + "build_cehr_sequences", + # Labels + "infer_mortality_label", +] + +EVENT_TYPE_TO_TOKEN_TYPE = { + "encounter": 1, + "condition": 2, + "medication_request": 3, + "observation": 4, + "procedure": 5, +} + +# Table-driven lookups for flattened event-row column access. +_CONCEPT_KEY_COL: Dict[str, str] = { + "condition": "condition/concept_key", + "observation": "observation/concept_key", + "medication_request": "medication_request/concept_key", + "procedure": "procedure/concept_key", +} + +_ENCOUNTER_ID_COL: Dict[str, str] = { + "condition": "condition/encounter_id", + "observation": "observation/encounter_id", + "medication_request": "medication_request/encounter_id", + "procedure": "procedure/encounter_id", + "encounter": "encounter/encounter_id", +} + +# --------------------------------------------------------------------------- +# ConceptVocab +# --------------------------------------------------------------------------- + + +@dataclass +class ConceptVocab: + """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1.""" + + token_to_id: Dict[str, int] = field(default_factory=dict) + pad_id: int = DEFAULT_PAD + unk_id: int = DEFAULT_UNK + _next_id: int = 2 + + def __post_init__(self) -> None: + if not self.token_to_id: + self.token_to_id = {"": self.pad_id, "": self.unk_id} + self._next_id = 2 + + def add_token(self, key: str) -> int: + if key in self.token_to_id: + return 
self.token_to_id[key] + tid = self._next_id + self.token_to_id[key] = tid + self._next_id += 1 + return tid + + def __getitem__(self, key: str) -> int: + return self.token_to_id.get(key, self.unk_id) + + @property + def vocab_size(self) -> int: + return self._next_id + + def to_json(self) -> Dict[str, Any]: + return {"token_to_id": self.token_to_id, "next_id": self._next_id} + + @classmethod + def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab": + vocab = cls() + loaded = dict(data.get("token_to_id") or {}) + if not loaded: + vocab._next_id = int(data.get("next_id", 2)) + return vocab + vocab.token_to_id = loaded + vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1)) + return vocab + + def save(self, path: str) -> None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS)) + + @classmethod + def load(cls, path: str) -> "ConceptVocab": + return cls.from_json(orjson.loads(Path(path).read_bytes())) + + +def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]: + """Add EHRMamba/CEHR special tokens and return their ids.""" + return {name: vocab.add_token(name) for name in ("<cls>", "<vs>", "<ve>", "<reg>")} + + +# --------------------------------------------------------------------------- +# Row utilities for flattened event stream +# --------------------------------------------------------------------------- + + +def _clean_string(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + return value.strip() or None + return str(value) + + +def _deceased_boolean_column_means_dead(value: Any) -> bool: + """True only for an explicit affirmative stored flag (not Python truthiness).""" + s = _clean_string(value) + return s is not None and s.lower() == "true" + + +def _row_datetime(value: Any) -> Optional[datetime]: + if value is None: + return None + if isinstance(value, datetime): + return as_naive(value) + try: + return
parse_dt(str(value)) + except Exception: + return None + + +def _concept_key_from_row(row: Dict[str, Any]) -> str: + event_type = row.get("event_type") + col = _CONCEPT_KEY_COL.get(event_type) + if col: + return _clean_string(row.get(col)) or f"{event_type}|unknown" + if event_type == "encounter": + enc_class = _clean_string(row.get("encounter/encounter_class")) + return f"encounter|{enc_class}" if enc_class else "encounter|unknown" + return f"{event_type or 'event'}|unknown" + + +def _linked_encounter_id_from_row(row: Dict[str, Any]) -> Optional[str]: + col = _ENCOUNTER_ID_COL.get(row.get("event_type")) + return _clean_string(row.get(col)) if col else None + + +def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]: + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") != "patient": + continue + birth = _row_datetime(row.get("timestamp")) + if birth is not None: + return birth + raw = _clean_string(row.get("patient/birth_date")) + if raw: + return parse_dt(raw) + return None + + +def _sequential_visit_idx_for_time( + event_time: Optional[datetime], + visit_encounters: List[Tuple[datetime, int]], +) -> int: + if not visit_encounters: + return 0 + if event_time is None: + return visit_encounters[-1][1] + event_time = as_naive(event_time) + chosen = visit_encounters[0][1] + for encounter_start, visit_idx in visit_encounters: + if encounter_start <= event_time: + chosen = visit_idx + else: + break + return chosen + + +# --------------------------------------------------------------------------- +# CEHR timeline and sequence building +# --------------------------------------------------------------------------- + + +def collect_cehr_timeline_events( + patient: Patient, +) -> List[Tuple[datetime, str, str, int]]: + """Collect (time, concept_key, event_type, visit_idx) tuples from a patient's rows.""" + rows = list( + patient.data_source.sort(["timestamp", "event_type"], nulls_last=True).iter_rows(named=True) + ) + + 
encounter_rows: List[Tuple[datetime, str]] = [] + for row in rows: + if row.get("event_type") != "encounter": + continue + enc_id = _linked_encounter_id_from_row(row) + enc_start = _row_datetime(row.get("timestamp")) + if enc_id is not None and enc_start is not None: + encounter_rows.append((enc_start, enc_id)) + + encounter_rows.sort(key=lambda pair: pair[0]) + encounter_visit_idx = {enc_id: idx for idx, (_, enc_id) in enumerate(encounter_rows)} + encounter_start_by_id = {enc_id: enc_start for enc_start, enc_id in encounter_rows} + visit_encounters = [(enc_start, idx) for idx, (enc_start, _) in enumerate(encounter_rows)] + + events: List[Tuple[datetime, str, str, int]] = [] + unlinked: List[Tuple[Optional[datetime], str, str]] = [] + + for row in rows: + event_type = row.get("event_type") + if event_type not in EVENT_TYPE_TO_TOKEN_TYPE: + continue + event_time = _row_datetime(row.get("timestamp")) + concept_key = _concept_key_from_row(row) + + if event_type == "encounter": + enc_id = _linked_encounter_id_from_row(row) + if enc_id is None or event_time is None: + continue + visit_idx = encounter_visit_idx.get(enc_id) + if visit_idx is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + continue + + enc_id = _linked_encounter_id_from_row(row) + if enc_id and enc_id in encounter_visit_idx: + visit_idx = encounter_visit_idx[enc_id] + if event_time is None: + event_time = encounter_start_by_id.get(enc_id) + if event_time is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + else: + unlinked.append((event_time, concept_key, event_type)) + + for event_time, concept_key, event_type in unlinked: + visit_idx = _sequential_visit_idx_for_time(event_time, visit_encounters) + if event_time is None: + if not visit_encounters: + continue + for enc_start, enc_visit_idx in visit_encounters: + if enc_visit_idx == visit_idx: + event_time = enc_start + break + else: + event_time = visit_encounters[-1][0] + if 
event_time is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + + events.sort(key=lambda item: item[0]) + return events + + +def warm_mpf_vocab_from_patient( + vocab: ConceptVocab, + patient: Patient, + clinical_cap: int, +) -> None: + """Add concept keys from the last clinical_cap events of a patient to vocab.""" + tail = collect_cehr_timeline_events(patient)[-clinical_cap:] if clinical_cap > 0 else [] + for _, concept_key, _, _ in tail: + vocab.add_token(concept_key) + + +def build_cehr_sequences( + patient: Patient, + vocab: ConceptVocab, + max_len: int, + *, + base_time: Optional[datetime] = None, + grow_vocab: bool = True, +) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]: + """Flatten a patient's tabular FHIR rows into CEHR-aligned feature lists.""" + events = collect_cehr_timeline_events(patient) + birth = _birth_datetime_from_patient(patient) + + if base_time is None: + base_time = events[0][0] if events else datetime.now() + base_time = as_naive(base_time) + birth = as_naive(birth) + + concept_ids: List[int] = [] + token_types: List[int] = [] + time_stamps: List[float] = [] + ages: List[float] = [] + visit_orders: List[int] = [] + visit_segments: List[int] = [] + + for event_time, concept_key, event_type, visit_idx in (events[-max_len:] if max_len > 0 else []): + event_time = as_naive(event_time) + concept_id = vocab.add_token(concept_key) if grow_vocab else vocab[concept_key] + token_type = EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0) + time_delta = ( + float((event_time - base_time).total_seconds()) + if base_time is not None and event_time is not None + else 0.0 + ) + age_years = ( + (event_time - birth).days / 365.25 + if birth is not None and event_time is not None + else 0.0 + ) + concept_ids.append(concept_id) + token_types.append(token_type) + time_stamps.append(time_delta) + ages.append(age_years) + visit_orders.append(min(visit_idx, 511)) + visit_segments.append(visit_idx % 2) + + 
return concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments + + +# --------------------------------------------------------------------------- +# Label inference +# --------------------------------------------------------------------------- + + +def infer_mortality_label(patient: Patient) -> int: + """Heuristic binary mortality label from flattened patient rows.""" + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") == "patient": + if _deceased_boolean_column_means_dead(row.get("patient/deceased_boolean")): + return 1 + if _clean_string(row.get("patient/deceased_datetime")): + return 1 + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") == "condition": + key = (_clean_string(row.get("condition/concept_key")) or "").lower() + if any(token in key for token in ("death", "deceased", "mortality")): + return 1 + return 0 diff --git a/pyhealth/datasets/fhir_ingest.py b/pyhealth/datasets/fhir_ingest.py new file mode 100644 index 000000000..e4926e58d --- /dev/null +++ b/pyhealth/datasets/fhir_ingest.py @@ -0,0 +1,500 @@ +"""FHIR NDJSON parsing, flattening, and Parquet table writing. + +Key public API +-------------- +stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir) + Stream all matching NDJSON/NDJSON.GZ resources into six per-type Parquet tables. + +sorted_ndjson_files(root, glob_pattern) + List matching NDJSON files under root (deduplicated, sorted). + +filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids) + Subset existing flattened tables to a specific patient cohort. + +synthetic_mpf_two_patient_ndjson_text() + In-memory two-patient fixture for tests and quick-test mode. 
+""" + +from __future__ import annotations + +import gzip +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple + +import orjson +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq + +GlobPatternArg = str | Sequence[str] +"""Single glob string or sequence of strings for NDJSON file discovery.""" + +__all__ = [ + # Types + "GlobPatternArg", + # Constants + "FHIR_SCHEMA_VERSION", + "FHIR_TABLES", + "FHIR_TABLES_FOR_PATIENT_IDS", + "FHIR_TABLE_FILE_NAMES", + "FHIR_TABLE_COLUMNS", + # Datetime helpers (also used by fhir_cehr) + "parse_dt", + "as_naive", + # FHIR iteration + "iter_ndjson_objects", + "iter_resources_from_ndjson_obj", + # Resource extraction + "patient_id_for_resource", + # Pipeline + "sorted_ndjson_files", + "stream_fhir_ndjson_to_flat_tables", + "filter_flat_tables_by_patient_ids", + # Synthetic fixtures + "synthetic_mpf_one_patient_resources", + "synthetic_mpf_two_patient_resources", + "synthetic_mpf_one_patient_ndjson_text", + "synthetic_mpf_two_patient_ndjson_text", +] + +FHIR_SCHEMA_VERSION = 3 + +FHIR_TABLES: List[str] = [ + "patient", + "encounter", + "condition", + "observation", + "medication_request", + "procedure", +] + +FHIR_TABLES_FOR_PATIENT_IDS: List[str] = [t for t in FHIR_TABLES if t != "patient"] + +FHIR_TABLE_FILE_NAMES: Dict[str, str] = {t: f"{t}.parquet" for t in FHIR_TABLES} + +FHIR_TABLE_COLUMNS: Dict[str, List[str]] = { + "patient": ["patient_id", "patient_fhir_id", "birth_date", "gender", "deceased_boolean", "deceased_datetime"], + "encounter": ["patient_id", "resource_id", "encounter_id", "event_time", "encounter_class", "encounter_end"], + "condition": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], + "observation": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], + "medication_request": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], + "procedure": 
["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], +} + +# --------------------------------------------------------------------------- +# Datetime helpers (also imported by fhir_cehr) +# --------------------------------------------------------------------------- + + +def parse_dt(s: Optional[str]) -> Optional[datetime]: + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + except ValueError: + dt = None + if dt is None and len(s) >= 10: + try: + dt = datetime.strptime(s[:10], "%Y-%m-%d") + except ValueError: + return None + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +def as_naive(dt: Optional[datetime]) -> Optional[datetime]: + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +# --------------------------------------------------------------------------- +# FHIR JSON helpers +# --------------------------------------------------------------------------- + + +def _coding_key(coding: Dict[str, Any]) -> str: + return f"{coding.get('system') or 'unknown'}|{coding.get('code') or 'unknown'}" + + +def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]: + if not obj: + return None + codings = obj.get("coding") or [] + if not codings and "concept" in obj: + codings = (obj.get("concept") or {}).get("coding") or [] + return _coding_key(codings[0]) if codings else None + + +def _ref_id(ref: Optional[str]) -> Optional[str]: + if not ref: + return None + return ref.rsplit("/", 1)[-1] if "/" in ref else ref + + +def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]: + if not isinstance(raw, dict): + return None + resource = raw.get("resource") if "resource" in raw else raw + return resource if isinstance(resource, dict) else None + + +def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + """Yield resource dicts from one parsed NDJSON object (Bundle or bare 
resource).""" + if "entry" in obj: + for entry in obj.get("entry") or []: + resource = entry.get("resource") + if isinstance(resource, dict): + yield resource + return + resource = _unwrap_resource_dict(obj) + if resource is not None: + yield resource + + +def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]: + """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file.""" + opener = ( + gzip.open(path, "rt", encoding="utf-8", errors="replace") + if path.suffix == ".gz" + else open(path, encoding="utf-8", errors="replace") + ) + with opener as stream: + for line in stream: + line = line.strip() + if not line: + continue + parsed = orjson.loads(line) + if isinstance(parsed, dict): + yield parsed + + +# --------------------------------------------------------------------------- +# Resource field extraction +# --------------------------------------------------------------------------- + + +def _clinical_concept_key(res: Dict[str, Any]) -> Optional[str]: + """Resolve a stable token key from a FHIR resource.""" + resource_type = res.get("resourceType") + if resource_type == "MedicationRequest": + med_cc = res.get("medicationCodeableConcept") + if isinstance(med_cc, dict): + key = _first_coding(med_cc) + if key: + return key + med_ref = res.get("medicationReference") + if isinstance(med_ref, dict): + ref = med_ref.get("reference") + if ref: + return f"MedicationRequest/reference|{_ref_id(ref) or ref}" + return None + code = res.get("code") + return _first_coding(code) if isinstance(code, dict) else None + + +def patient_id_for_resource( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or resource.get("resourceType") + if resource_type == "Patient": + pid = resource.get("id") + return str(pid) if pid is not None else None + if resource_type in {"Encounter", "Condition", "Observation", "MedicationRequest", "Procedure"}: + return _ref_id((resource.get("subject") or 
{}).get("reference")) + return None + + +def _resource_time_string( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or resource.get("resourceType") + if resource_type == "Patient": + return resource.get("birthDate") + if resource_type == "Encounter": + return (resource.get("period") or {}).get("start") + if resource_type == "Condition": + return resource.get("onsetDateTime") or resource.get("recordedDate") + if resource_type == "Observation": + return resource.get("effectiveDateTime") or resource.get("issued") + if resource_type == "MedicationRequest": + return resource.get("authoredOn") + if resource_type == "Procedure": + return resource.get("performedDateTime") or resource.get("recordedDate") + return None + + +# --------------------------------------------------------------------------- +# Flattening +# --------------------------------------------------------------------------- + + +def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]: + """Map Patient.deceasedBoolean to stored "true"/"false"/None. + + FHIR JSON uses real booleans; some exports use strings. Python's + bool("false") is True, so we must not coerce with bool(). 
+ """ + if value is None: + return None + if value is True: + return "true" + if value is False: + return "false" + if isinstance(value, str): + key = value.strip().lower() + if key in ("true", "1", "yes", "y", "t"): + return "true" + if key in ("false", "0", "no", "n", "f", ""): + return "false" + return None + if isinstance(value, (int, float)) and not isinstance(value, bool): + if value == 0: + return "false" + if value == 1: + return "true" + return None + return None + + +_RESOURCE_TYPE_TO_TABLE: Dict[str, str] = { + "Condition": "condition", + "Observation": "observation", + "MedicationRequest": "medication_request", + "Procedure": "procedure", +} + + +def _flatten_resource_to_table_row( + resource: Dict[str, Any], +) -> Optional[Tuple[str, Dict[str, Optional[str]]]]: + """Map one FHIR resource dict to (table_name, row_dict), or None if unsupported.""" + resource_type = resource.get("resourceType") + patient_id = patient_id_for_resource(resource, resource_type) + if not patient_id: + return None + + if resource_type == "Patient": + return "patient", { + "patient_id": patient_id, + "patient_fhir_id": str(resource.get("id") or patient_id), + "birth_date": resource.get("birthDate"), + "gender": resource.get("gender"), + "deceased_boolean": _normalize_deceased_boolean_for_storage(resource.get("deceasedBoolean")), + "deceased_datetime": resource.get("deceasedDateTime"), + } + + resource_id = str(resource.get("id")) if resource.get("id") is not None else None + event_time = _resource_time_string(resource, resource_type) + + if resource_type == "Encounter": + return "encounter", { + "patient_id": patient_id, + "resource_id": resource_id, + "encounter_id": resource_id, + "event_time": event_time, + "encounter_class": (resource.get("class") or {}).get("code"), + "encounter_end": (resource.get("period") or {}).get("end"), + } + + table_name = _RESOURCE_TYPE_TO_TABLE.get(resource_type) + if table_name is None: + return None + return table_name, { + "patient_id": 
patient_id, + "resource_id": resource_id, + "encounter_id": _ref_id((resource.get("encounter") or {}).get("reference")), + "event_time": event_time, + "concept_key": _clinical_concept_key(resource), + } + + +# --------------------------------------------------------------------------- +# Parquet writer +# --------------------------------------------------------------------------- + + +def _table_schema(table_name: str) -> pa.Schema: + return pa.schema([(col, pa.string()) for col in FHIR_TABLE_COLUMNS[table_name]]) + + +class _BufferedParquetWriter: + def __init__(self, path: Path, schema: pa.Schema, batch_size: int = 50_000) -> None: + self.path = path + self.schema = schema + self.batch_size = batch_size + self.rows: List[Dict[str, Any]] = [] + self.writer: Optional[pq.ParquetWriter] = None + self.path.parent.mkdir(parents=True, exist_ok=True) + + def add(self, row: Dict[str, Any]) -> None: + self.rows.append(row) + if len(self.rows) >= self.batch_size: + self.flush() + + def flush(self) -> None: + if not self.rows: + return + table = pa.Table.from_pylist(self.rows, schema=self.schema) + if self.writer is None: + self.writer = pq.ParquetWriter(str(self.path), self.schema) + self.writer.write_table(table) + self.rows.clear() + + def close(self) -> None: + self.flush() + if self.writer is None: + pq.write_table(pa.Table.from_pylist([], schema=self.schema), str(self.path)) + return + self.writer.close() + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]: + """Return sorted unique file paths under root matching glob pattern(s). + + Args: + root: Root directory to search under. + glob_pattern: Single glob string or sequence of glob strings. + + Returns: + Sorted list of matching files. Empty if no matches. 
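The dedup-then-sort shape of `sorted_ndjson_files` can be shown with `pathlib` alone: overlapping patterns may match the same file, so a set removes duplicates before a path-string sort makes the ingest order deterministic. Paths below are throwaway fixtures.

```python
# Multi-pattern file discovery with dedup + stable ordering (stdlib sketch).
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "a.ndjson").write_text("{}\n")
(root / "sub").mkdir()
(root / "sub" / "b.ndjson.gz").write_bytes(b"")

patterns = ["**/*.ndjson", "**/*.ndjson.gz", "*.ndjson"]  # overlap on purpose
files = set()
for pat in patterns:
    files.update(p for p in root.glob(pat) if p.is_file())  # set dedups overlaps
found = sorted(files, key=str)  # deterministic order across runs

names = [p.name for p in found]
```

Note that `**/*.ndjson` in `pathlib` also matches top-level files, since `**` can match zero directories.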
+ """ + patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern) + files: set[Path] = set() + for pat in patterns: + files.update(p for p in root.glob(pat) if p.is_file()) + return sorted(files, key=lambda p: str(p)) + + +def stream_fhir_ndjson_to_flat_tables( + root: Path, + glob_pattern: GlobPatternArg, + out_dir: Path, +) -> None: + """Stream NDJSON resources into normalized per-resource Parquet tables under out_dir. + + Args: + root: Root directory containing NDJSON/NDJSON.GZ files. + glob_pattern: Single glob string or sequence of glob strings. + out_dir: Output directory for per-resource-type Parquet tables. + Creates patient.parquet, encounter.parquet, condition.parquet, + observation.parquet, medication_request.parquet, procedure.parquet. + """ + out_dir.mkdir(parents=True, exist_ok=True) + writers = { + name: _BufferedParquetWriter(path=out_dir / FHIR_TABLE_FILE_NAMES[name], schema=_table_schema(name)) + for name in FHIR_TABLES + } + try: + for file_path in sorted_ndjson_files(root, glob_pattern): + for ndjson_obj in iter_ndjson_objects(file_path): + for resource in iter_resources_from_ndjson_obj(ndjson_obj): + result = _flatten_resource_to_table_row(resource) + if result is not None: + writers[result[0]].add(result[1]) + finally: + for writer in writers.values(): + writer.close() + + +def _sorted_patient_ids_from_flat_tables(table_dir: Path) -> List[str]: + patient_table = table_dir / FHIR_TABLE_FILE_NAMES["patient"] + if patient_table.exists(): + return ( + pl.scan_parquet(str(patient_table)) + .select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + frames = [ + pl.scan_parquet(str(table_dir / FHIR_TABLE_FILE_NAMES[t])).select("patient_id") + for t in FHIR_TABLES_FOR_PATIENT_IDS + ] + return ( + pl.concat(frames) + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + + +def filter_flat_tables_by_patient_ids( + source_dir: 
Path, + out_dir: Path, + keep_ids: Sequence[str], +) -> None: + """Filter all flattened tables to only include rows for the given patient IDs.""" + out_dir.mkdir(parents=True, exist_ok=True) + keep_set = set(keep_ids) + for name in FHIR_TABLES: + src = source_dir / FHIR_TABLE_FILE_NAMES[name] + dst = out_dir / FHIR_TABLE_FILE_NAMES[name] + pl.scan_parquet(str(src)).filter(pl.col("patient_id").is_in(keep_set)).sink_parquet(str(dst)) + + +# --------------------------------------------------------------------------- +# Synthetic fixtures (for tests and --quick-test mode) +# --------------------------------------------------------------------------- + + +def synthetic_mpf_one_patient_resources() -> List[Dict[str, Any]]: + return [ + {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"}, + { + "resourceType": "Encounter", + "id": "e1", + "subject": {"reference": "Patient/p-synth-1"}, + "period": {"start": "2020-06-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Condition", + "id": "c1", + "subject": {"reference": "Patient/p-synth-1"}, + "encounter": {"reference": "Encounter/e1"}, + "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]}, + "onsetDateTime": "2020-06-01T11:00:00Z", + }, + ] + + +def synthetic_mpf_two_patient_resources() -> List[Dict[str, Any]]: + return [ + *synthetic_mpf_one_patient_resources(), + {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True}, + { + "resourceType": "Encounter", + "id": "e-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "period": {"start": "2020-07-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Observation", + "id": "o-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "encounter": {"reference": "Encounter/e-dead"}, + "effectiveDateTime": "2020-07-01T12:00:00Z", + "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]}, + }, + ] + + +def 
synthetic_mpf_one_patient_ndjson_text() -> str: + return "\n".join(orjson.dumps(r).decode("utf-8") for r in synthetic_mpf_one_patient_resources()) + "\n" + + +def synthetic_mpf_two_patient_ndjson_text() -> str: + return "\n".join(orjson.dumps(r).decode("utf-8") for r in synthetic_mpf_two_patient_resources()) + "\n" diff --git a/pyhealth/datasets/mimic4_fhir.py b/pyhealth/datasets/mimic4_fhir.py new file mode 100644 index 000000000..d9f8e428a --- /dev/null +++ b/pyhealth/datasets/mimic4_fhir.py @@ -0,0 +1,369 @@ +"""MIMIC-IV FHIR ingestion using flattened resource tables. + +Architecture +------------ +1. Stream NDJSON/NDJSON.GZ FHIR resources from disk. +2. Normalize each resource type into a 2D table (Patient, Encounter, Condition, + Observation, MedicationRequest, Procedure) via :mod:`~pyhealth.datasets.fhir_ingest`. +3. Feed those tables through the standard YAML-driven + :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task processing + operates on :class:`~pyhealth.data.Patient` and ``global_event_df`` rows. + +Module layout +------------- +- :mod:`~pyhealth.datasets.fhir_ingest` -- NDJSON → Parquet pipeline. +- :mod:`~pyhealth.datasets.fhir_cehr` -- CEHR tokenization, ConceptVocab, labels. +- This module -- :class:`MIMIC4FHIRDataset` wiring ingest + into :class:`~pyhealth.datasets.BaseDataset`. 
+""" + +from __future__ import annotations + +import functools +import hashlib +import itertools +import logging +import operator +import os +import shutil +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +import dask.dataframe as dd +import narwhals as nw +import orjson +import pandas as pd +import platformdirs +import polars as pl +from litdata.processing.data_processor import in_notebook +from yaml import safe_load + +from ..data import Patient +from .base_dataset import BaseDataset +from .fhir_cehr import ConceptVocab, ensure_special_tokens, warm_mpf_vocab_from_patient +from .fhir_ingest import ( + FHIR_SCHEMA_VERSION, + FHIR_TABLE_FILE_NAMES, + FHIR_TABLES, + _sorted_patient_ids_from_flat_tables, + filter_flat_tables_by_patient_ids, + stream_fhir_ndjson_to_flat_tables, +) + +logger = logging.getLogger(__name__) + + +def read_fhir_settings_yaml(path: Optional[str] = None) -> Dict[str, Any]: + if path is None: + path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_fhir.yaml") + with open(path, encoding="utf-8") as stream: + data = safe_load(stream) + return data if isinstance(data, dict) else {} + + +def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series: + if getattr(part.dtype, "tz", None) is not None: + part = part.dt.tz_localize(None) + return part.astype("datetime64[ms]") + + +class MIMIC4FHIRDataset(BaseDataset): + """MIMIC-IV FHIR with flattened resource tables and standard PyHealth task flow. + + Streams raw MIMIC-IV FHIR NDJSON/NDJSON.GZ exports into six flattened Parquet + tables then pipelines them through :class:`~pyhealth.datasets.BaseDataset` for + standard downstream task processing (global event dataframe, patient iteration, + task sampling). + + **Ingest flow (out-of-core):** + + 1. Scan NDJSON files matching ``glob_patterns`` (defaults to six Mimic* families). + 2. Parse and flatten each FHIR resource into a row in the appropriate table. + 3. 
Cache normalized tables under ``cache_dir / {uuid} / flattened_tables/``. + 4. Load tables into ``global_event_df`` via YAML config. + + **Cache fingerprinting:** includes ``glob_patterns`` and YAML digest, so changes + to either create a new independent cache. + """ + + def __init__( + self, + root: str, + config_path: Optional[str] = None, + glob_pattern: Optional[str] = None, + glob_patterns: Optional[Sequence[str]] = None, + max_patients: Optional[int] = None, + ingest_num_shards: Optional[int] = None, + vocab_path: Optional[str] = None, + cache_dir: Optional[str | Path] = None, + num_workers: int = 1, + dev: bool = False, + ) -> None: + """Initialize a MIMIC-IV FHIR dataset. + + Args: + root: Path to the NDJSON/NDJSON.GZ export directory. + config_path: Path to a custom YAML config. Defaults to + ``pyhealth/datasets/configs/mimic4_fhir.yaml``. + glob_pattern: Single glob for NDJSON files. Mutually exclusive + with ``glob_patterns``. + glob_patterns: Multiple glob patterns. Mutually exclusive with + ``glob_pattern``. + max_patients: Limit ingest to the first N unique patient IDs. + ingest_num_shards: Ignored; retained for API compatibility. + vocab_path: Path to a pre-built ConceptVocab JSON file. + cache_dir: Cache directory root (UUID subdir appended per config). + num_workers: Worker processes for task sampling. + dev: Development mode; limits to 1000 patients if max_patients is None. 
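The cache-fingerprinting scheme described above boils down to: serialize the config identity with sorted keys, then derive a UUIDv5 from it, so any change to the source root, glob patterns, or YAML digest lands in a fresh cache directory. A stdlib mirror (field names illustrative, reduced from the real identity dict):

```python
# Deterministic cache id from a config identity (stdlib sketch).
import hashlib
import json
import uuid

def cache_id(source_root, glob_patterns, yaml_bytes):
    identity = json.dumps(
        {
            "source_root": source_root,
            "glob_patterns": list(glob_patterns),
            # digest of the YAML file, so editing the config invalidates the cache
            "fhir_yaml_digest16": hashlib.sha256(yaml_bytes).hexdigest()[:16],
        },
        sort_keys=True,  # key order must not affect the id
    )
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, identity))

a = cache_id("/data/fhir", ["**/*.ndjson.gz"], b"tables: {}")
b = cache_id("/data/fhir", ["**/*.ndjson.gz"], b"tables: {}")
c = cache_id("/data/fhir", ["mimic/*.ndjson.gz"], b"tables: {}")
```

Same inputs always map to the same directory; any changed input maps to a new one.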
+ """ + del ingest_num_shards + + default_cfg = os.path.join(os.path.dirname(__file__), "configs", "mimic4_fhir.yaml") + self._fhir_config_path = str(Path(config_path or default_cfg).resolve()) + self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path) + + if glob_pattern is not None and glob_patterns is not None: + raise ValueError("Pass at most one of glob_pattern and glob_patterns.") + if glob_patterns is not None: + self.glob_patterns: List[str] = list(glob_patterns) + elif glob_pattern is not None: + self.glob_patterns = [glob_pattern] + else: + raw_list = self._fhir_settings.get("glob_patterns") + if raw_list: + if not isinstance(raw_list, list): + raise TypeError("mimic4_fhir.yaml glob_patterns must be a list of strings.") + self.glob_patterns = [str(x) for x in raw_list] + elif self._fhir_settings.get("glob_pattern") is not None: + self.glob_patterns = [str(self._fhir_settings["glob_pattern"])] + else: + self.glob_patterns = ["**/*.ndjson.gz"] + + self.glob_pattern = ( + self.glob_patterns[0] if len(self.glob_patterns) == 1 else "; ".join(self.glob_patterns) + ) + self.max_patients = 1000 if dev and max_patients is None else max_patients + self.source_root = str(Path(root).expanduser().resolve()) + self.vocab = ( + ConceptVocab.load(vocab_path) + if vocab_path and os.path.isfile(vocab_path) + else ConceptVocab() + ) + super().__init__( + root=self.source_root, + tables=FHIR_TABLES, + dataset_name="mimic4_fhir", + config_path=self._fhir_config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: + try: + yaml_digest = hashlib.sha256(Path(self._fhir_config_path).read_bytes()).hexdigest()[:16] + except OSError: + yaml_digest = "missing" + identity = orjson.dumps( + { + "source_root": self.source_root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + "glob_patterns": self.glob_patterns, + "max_patients": 
self.max_patients, + "fhir_schema_version": FHIR_SCHEMA_VERSION, + "fhir_yaml_digest16": yaml_digest, + }, + option=orjson.OPT_SORT_KEYS, + ).decode("utf-8") + cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity)) + out = ( + Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id + if cache_dir is None + else Path(cache_dir) / cache_id + ) + out.mkdir(parents=True, exist_ok=True) + logger.info("Cache dir: %s", out) + return out + + @property + def prepared_tables_dir(self) -> Path: + return self.cache_dir / "flattened_tables" + + def _ensure_prepared_tables(self) -> None: + root = Path(self.source_root) + if not root.is_dir(): + raise FileNotFoundError(f"MIMIC4 FHIR root not found: {root}") + + expected = [self.prepared_tables_dir / FHIR_TABLE_FILE_NAMES[t] for t in FHIR_TABLES] + if all(p.is_file() for p in expected): + return + if self.prepared_tables_dir.exists(): + shutil.rmtree(self.prepared_tables_dir) + + try: + staging_root = self.create_tmpdir() + staging = staging_root / "flattened_fhir_tables" + staging.mkdir(parents=True, exist_ok=True) + stream_fhir_ndjson_to_flat_tables(root, self.glob_patterns, staging) + + if self.max_patients is None: + shutil.move(str(staging), str(self.prepared_tables_dir)) + return + + filtered_root = self.create_tmpdir() + filtered = filtered_root / "flattened_fhir_tables_filtered" + patient_ids = _sorted_patient_ids_from_flat_tables(staging) + filter_flat_tables_by_patient_ids(staging, filtered, patient_ids[: self.max_patients]) + shutil.move(str(filtered), str(self.prepared_tables_dir)) + finally: + self.clean_tmpdir() + + def _event_transform(self, output_dir: Path) -> None: + self._ensure_prepared_tables() + super()._event_transform(output_dir) + + def load_table(self, table_name: str) -> dd.DataFrame: + """Load one flattened Parquet table into the standard event schema. + + Deviations from BaseDataset.load_table (which reads CSV via _scan_csv_tsv_gz): + - Reads from pre-built Parquet under prepared_tables_dir. 
+ - Timestamp parsing uses errors="coerce" + utc=True (FHIR ISO strings + include timezone suffix or are partial dates). + - Strips tz-aware timestamps to naive UTC for Dask compatibility. + - Drops rows with null patient_id before returning. + """ + assert self.config is not None + if table_name not in self.config.tables: + raise ValueError(f"Table {table_name} not found in config") + + table_cfg = self.config.tables[table_name] + path = self.prepared_tables_dir / table_cfg.file_path + if not path.exists(): + raise FileNotFoundError(f"Flattened table not found: {path}") + + logger.info("Scanning FHIR flattened table: %s from %s", table_name, path) + df: dd.DataFrame = dd.read_parquet( + str(path), split_row_groups=True, blocksize="64MB" + ).replace("", pd.NA) + df = df.rename(columns=str.lower) + + preprocess_func = getattr(self, f"preprocess_{table_name}", None) + if preprocess_func is not None: + logger.info("Preprocessing FHIR table: %s with %s", table_name, preprocess_func.__name__) + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr] + + for join_cfg in table_cfg.join: + join_path = self.prepared_tables_dir / Path(join_cfg.file_path).name + if not join_path.exists(): + raise FileNotFoundError(f"FHIR join table not found: {join_path}") + logger.info("Joining FHIR table %s with %s", table_name, join_path) + join_df: dd.DataFrame = dd.read_parquet( + str(join_path), split_row_groups=True, blocksize="64MB" + ).replace("", pd.NA) + join_df = join_df.rename(columns=str.lower) + join_key = join_cfg.on.lower() + cols = [c.lower() for c in join_cfg.columns] + df = df.merge(join_df[[join_key] + cols], on=join_key, how=join_cfg.how) + + ts_col = table_cfg.timestamp + if ts_col: + ts = ( + functools.reduce(operator.add, (df[c].astype("string") for c in ts_col)) + if isinstance(ts_col, list) + else df[ts_col].astype("string") + ) + ts = dd.to_datetime(ts, format=table_cfg.timestamp_format, errors="coerce", utc=True) + df = 
df.assign(timestamp=ts.map_partitions(_strip_tz_to_naive_ms)) + else: + df = df.assign(timestamp=pd.NaT) + + if table_cfg.patient_id: + df = df.assign(patient_id=df[table_cfg.patient_id].astype("string")) + else: + df = df.reset_index(drop=True) + df = df.assign(patient_id=df.index.astype("string")) + + df = df.dropna(subset=["patient_id"]) + df = df.assign(event_type=table_name) + rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in table_cfg.attributes} + df = df.rename(columns=rename_attr) + return df[["patient_id", "event_type", "timestamp"] + [rename_attr[a.lower()] for a in table_cfg.attributes]] + + @property + def unique_patient_ids(self) -> List[str]: + if self._unique_patient_ids is None: + self._unique_patient_ids = ( + self.global_event_df + .select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + logger.info("Found %d unique patient IDs", len(self._unique_patient_ids)) + return self._unique_patient_ids + + def set_task( + self, + task: Any = None, + num_workers: Optional[int] = None, + input_processors: Optional[Any] = None, + output_processors: Optional[Any] = None, + ) -> Any: + self._main_guard(self.set_task.__name__) + if task is None: + raise ValueError("Pass a task instance, e.g. 
MPFClinicalPredictionTask(max_len=512).") + + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + if isinstance(task, MPFClinicalPredictionTask): + worker_count = ( + 1 if in_notebook() else (num_workers if num_workers is not None else self.num_workers) + ) + warmup_pids = self._mpf_patient_ids_for_task(task) + patient_count = len(warmup_pids) + effective_workers = min(worker_count, patient_count) if patient_count else 1 + ensure_special_tokens(self.vocab) + self._warm_mpf_vocabulary(task, warmup_pids) + task.frozen_vocab = effective_workers > 1 + task.vocab = self.vocab + task._specials = ensure_special_tokens(self.vocab) + + return super().set_task(task, num_workers, input_processors, output_processors) + + def _mpf_patient_ids_for_task(self, task: Any) -> List[str]: + return ( + task.pre_filter(self.global_event_df) + .select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + .sort() + .to_list() + ) + + def _warm_mpf_vocabulary(self, task: Any, patient_ids: List[str]) -> None: + clinical_cap = max(0, task.max_len - 2) + base = self.global_event_df + for batch in itertools.batched(patient_ids, 128): + patients = ( + base.filter(pl.col("patient_id").is_in(batch)) + .collect(engine="streaming") + .partition_by("patient_id", as_dict=True) + ) + for patient_key, patient_df in patients.items(): + warm_mpf_vocab_from_patient( + self.vocab, Patient(patient_id=patient_key[0], data_source=patient_df), clinical_cap + ) + + def gather_samples(self, task: Any) -> List[Dict[str, Any]]: + task.vocab = self.vocab + task._specials = None + task.frozen_vocab = False + samples: List[Dict[str, Any]] = [] + for patient in self.iter_patients(): + samples.extend(task(patient)) + return samples diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..86486dab8 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -38,6 +38,7 @@ from .transformer import Transformer, 
TransformerLayer from .transformers_model import TransformersModel from .ehrmamba import EHRMamba, MambaBlock +from .ehrmamba_cehr import EHRMambaCEHR from .vae import VAE from .vision_embedding import VisionEmbeddingModel from .text_embedding import TextEmbedding diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py new file mode 100644 index 000000000..7974a699e --- /dev/null +++ b/pyhealth/models/cehr_embeddings.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Vector Institute / Odyssey authors +# +# Derived from Odyssey (https://github.com/VectorInstitute/odyssey): +# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, TimeEmbeddingLayer, VisitEmbedding +# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args. + +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch import nn + + +class TimeEmbeddingLayer(nn.Module): + """Embedding layer for time features (sinusoidal).""" + + def __init__(self, embedding_size: int, is_time_delta: bool = False): + super().__init__() + self.embedding_size = embedding_size + self.is_time_delta = is_time_delta + self.w = nn.Parameter(torch.empty(1, self.embedding_size)) + self.phi = nn.Parameter(torch.empty(1, self.embedding_size)) + nn.init.xavier_uniform_(self.w) + nn.init.xavier_uniform_(self.phi) + + def forward(self, time_stamps: torch.Tensor) -> torch.Tensor: + if self.is_time_delta: + time_stamps = torch.cat( + (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]), + dim=-1, + ) + time_stamps = time_stamps.float() + next_input = time_stamps.unsqueeze(-1) * self.w + self.phi + return torch.sin(next_input) + + +class VisitEmbedding(nn.Module): + """Embedding layer for visit segments.""" + + def __init__(self, visit_order_size: int, embedding_size: int): + super().__init__() + self.embedding = nn.Embedding(visit_order_size, embedding_size) + + def forward(self, visit_segments: 
torch.Tensor) -> torch.Tensor: + return self.embedding(visit_segments) + + +class MambaEmbeddingsForCEHR(nn.Module): + """CEHR-style combined embeddings for Mamba (concept + type + time + age + visit).""" + + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int = 0, + type_vocab_size: int = 9, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_order_size: int = 3, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.pad_token_id = pad_token_id + self.type_vocab_size = type_vocab_size + self.max_num_visits = max_num_visits + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size) + self.time_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=True + ) + self.age_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=False + ) + self.visit_segment_embeddings = VisitEmbedding( + visit_order_size=visit_order_size, embedding_size=hidden_size + ) + self.scale_back_concat_layer = nn.Linear( + hidden_size + 2 * time_embeddings_size, hidden_size + ) + self.tanh = nn.Tanh() + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids_batch: torch.Tensor, + time_stamps: torch.Tensor, + ages: torch.Tensor, + visit_orders: torch.Tensor, + visit_segments: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.word_embeddings(input_ids) + time_stamps_embeds = self.time_embeddings(time_stamps) + ages_embeds = self.age_embeddings(ages) + visit_segments_embeds = self.visit_segment_embeddings(visit_segments) + visit_order_embeds = self.visit_order_embeddings(visit_orders) + 
token_type_embeds = self.token_type_embeddings(token_type_ids_batch)
+        concat_in = torch.cat(
+            (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1
+        )
+        h = self.tanh(self.scale_back_concat_layer(concat_in))
+        embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds
+        embeddings = self.dropout(embeddings)
+        return self.LayerNorm(embeddings)
diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py
new file mode 100644
index 000000000..cd555629c
--- /dev/null
+++ b/pyhealth/models/ehrmamba_cehr.py
@@ -0,0 +1,117 @@
+"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from pyhealth.datasets import SampleDataset
+
+from .base_model import BaseModel
+from .cehr_embeddings import MambaEmbeddingsForCEHR
+from .ehrmamba import MambaBlock
+from .utils import get_rightmost_masked_timestep
+
+
+class EHRMambaCEHR(BaseModel):
+    """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline).
+
+    Args:
+        dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema.
+        vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``).
+        embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings).
+        num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers.
+        pad_token_id: Padding id for masking (default 0).
+        state_size: SSM state size per channel.
+        conv_kernel: Causal conv kernel size in each block.
+        dropout: Dropout probability applied before the classifier head.
+        type_vocab_size: Token-type vocabulary size for the CEHR type embeddings.
+        max_num_visits: Size of the visit-order embedding table.
+        time_embeddings_size: Width of the sinusoidal time and age embeddings.
+        visit_segment_vocab: Number of visit-segment ids (CEHR-style segments use 3).
+ """ + + def __init__( + self, + dataset: SampleDataset, + vocab_size: int, + embedding_dim: int = 128, + num_layers: int = 2, + pad_token_id: int = 0, + state_size: int = 16, + conv_kernel: int = 4, + dropout: float = 0.1, + type_vocab_size: int = 16, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_segment_vocab: int = 3, + ): + super().__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.num_layers = num_layers + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + + assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embeddings = MambaEmbeddingsForCEHR( + vocab_size=vocab_size, + hidden_size=embedding_dim, + pad_token_id=pad_token_id, + type_vocab_size=type_vocab_size, + max_num_visits=max_num_visits, + time_embeddings_size=time_embeddings_size, + visit_order_size=visit_segment_vocab, + ) + self.blocks = nn.ModuleList( + [ + MambaBlock( + d_model=embedding_dim, + state_size=state_size, + conv_kernel=conv_kernel, + ) + for _ in range(num_layers) + ] + ) + self.dropout = nn.Dropout(dropout) + out_dim = self.get_output_size() + self.fc = nn.Linear(embedding_dim, out_dim) + self._forecasting_head: Optional[nn.Module] = None + + def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]: + """Optional next-token / forecasting head (extension point; not implemented).""" + + return None + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + concept_ids = kwargs["concept_ids"].to(self.device).long() + token_type_ids = kwargs["token_type_ids"].to(self.device).long() + time_stamps = kwargs["time_stamps"].to(self.device).float() + ages = kwargs["ages"].to(self.device).float() + visit_orders = kwargs["visit_orders"].to(self.device).long() + visit_segments = kwargs["visit_segments"].to(self.device).long() + + x = self.embeddings( + input_ids=concept_ids, + 
token_type_ids_batch=token_type_ids, + time_stamps=time_stamps, + ages=ages, + visit_orders=visit_orders, + visit_segments=visit_segments, + ) + mask = concept_ids != self.pad_token_id + for blk in self.blocks: + x = blk(x) + pooled = get_rightmost_masked_timestep(x, mask) + logits = self.fc(self.dropout(pooled)) + y_true = kwargs[self.label_key].to(self.device).float() + if y_true.dim() == 1: + y_true = y_true.unsqueeze(-1) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py index 67edc010e..45cd6608d 100644 --- a/pyhealth/models/utils.py +++ b/pyhealth/models/utils.py @@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask): last_hidden_states = torch.gather(hidden_states, 1, last_visit) last_hidden_state = last_hidden_states[:, 0, :] return last_hidden_state + + +def get_rightmost_masked_timestep(hidden_states, mask): + """Gather hidden state at the last True position in ``mask`` per row. + + Unlike :func:`get_last_visit`, this does **not** assume valid tokens form a + contiguous prefix; it picks the maximum index where ``mask`` is True. + Use for MPF / CEHR layouts where padding can appear between boundary tokens. + + Args: + hidden_states: ``[batch, seq_len, hidden_size]``. + mask: ``[batch, seq_len]`` bool. + + Returns: + Tensor ``[batch, hidden_size]``. 
+ """ + if mask is None: + return hidden_states[:, -1, :] + batch, seq_len, hidden = hidden_states.shape + device = hidden_states.device + idx = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand( + batch, -1 + ) + idx_m = torch.where(mask, idx, torch.full_like(idx, -1)) + last_idx = idx_m.max(dim=1).values.clamp(min=0) + last_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden) + gathered = torch.gather(hidden_states, 1, last_idx) + return gathered[:, 0, :] diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..ffdf99560 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,11 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task + + +def __getattr__(name: str): + if name == "MPFClinicalPredictionTask": + from .mpf_clinical_prediction import MPFClinicalPredictionTask + + return MPFClinicalPredictionTask + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py new file mode 100644 index 000000000..34927d8ab --- /dev/null +++ b/pyhealth/tasks/mpf_clinical_prediction.py @@ -0,0 +1,152 @@ +"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch + +from pyhealth.data import Patient +from pyhealth.datasets.fhir_cehr import ( + ConceptVocab, + build_cehr_sequences, + ensure_special_tokens, + infer_mortality_label, +) + +from .base_task import BaseTask + + +def _pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]: + if len(seq) > max_len: + return seq[-max_len:] + return seq + [pad] * (max_len - len(seq)) + + +def _pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]: + if len(seq) > max_len: + return seq[-max_len:] + return seq + [pad] * (max_len - len(seq)) + + +def 
_left_pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]:
+    if len(seq) > max_len:
+        return seq[-max_len:]
+    return [pad] * (max_len - len(seq)) + seq
+
+
+def _left_pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]:
+    if len(seq) > max_len:
+        return seq[-max_len:]
+    return [pad] * (max_len - len(seq)) + seq
+
+
+class MPFClinicalPredictionTask(BaseTask):
+    """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens.
+
+    Operates on :class:`~pyhealth.data.Patient` objects via the standard
+    ``global_event_df`` / :meth:`~pyhealth.datasets.MIMIC4FHIRDataset.set_task`
+    path. During :meth:`set_task`, :class:`~pyhealth.datasets.MIMIC4FHIRDataset`
+    reserves the special tokens and warms concept keys in the main process over
+    the same patient cohort as
+    :meth:`~pyhealth.tasks.base_task.BaseTask.pre_filter` (including when
+    LitData skips ``_task_transform`` on a cache hit). When multiple workers run
+    :meth:`~pyhealth.datasets.BaseDataset._task_transform`, it also sets
+    :attr:`frozen_vocab` so worker processes do not race on
+    :class:`~pyhealth.datasets.mimic4_fhir.ConceptVocab`.
+
+    Attributes:
+        max_len: Maximum (left-padded) sequence length; must be >= 2 to fit the
+            boundary tokens.
+        use_mpf: If True, use the ``<mor>`` / ``<reg>`` specials; else ``<cls>`` / ``<reg>``.
+        vocab: Shared concept vocabulary (usually the dataset's vocab).
+        frozen_vocab: If True, do not add new concept ids (post-warmup parallel
+            path).
+ """ + + task_name: str = "MPFClinicalPredictionFHIR" + input_schema: Dict[str, Any] = { + "concept_ids": ("tensor", {"dtype": torch.long}), + "token_type_ids": ("tensor", {"dtype": torch.long}), + "time_stamps": "tensor", + "ages": "tensor", + "visit_orders": ("tensor", {"dtype": torch.long}), + "visit_segments": ("tensor", {"dtype": torch.long}), + } + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__(self, max_len: int = 512, use_mpf: bool = True) -> None: + if max_len < 2: + raise ValueError("max_len must be >= 2 for MPF boundary tokens") + self.max_len = max_len + self.use_mpf = use_mpf + self.vocab: Optional[ConceptVocab] = None + self._specials: Optional[Dict[str, int]] = None + self.frozen_vocab: bool = False + + def _ensure_vocab(self) -> ConceptVocab: + if self.vocab is None: + self.vocab = ConceptVocab() + if self._specials is None: + self._specials = ensure_special_tokens(self.vocab) + return self.vocab + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Build one labeled sample dict per patient. + + Args: + patient: A tabular :class:`~pyhealth.data.Patient`. + + Returns: + A one-element list with ``concept_ids``, tensor-ready feature lists, and + ``label`` (0/1). Boundary tokens are always included; when + ``max_len == 2`` the sequence is ````/```` and ```` only. 
+ """ + vocab = self._ensure_vocab() + pid = patient.patient_id + clinical_cap = max(0, self.max_len - 2) + ( + concept_ids, + token_types, + time_stamps, + ages, + visit_orders, + visit_segments, + ) = build_cehr_sequences( + patient, + vocab, + clinical_cap, + grow_vocab=not self.frozen_vocab, + ) + + assert self._specials is not None + mor_id = self._specials[""] if self.use_mpf else self._specials[""] + reg_id = self._specials[""] + z0 = 0 + zf = 0.0 + concept_ids = [mor_id] + concept_ids + [reg_id] + token_types = [z0] + token_types + [z0] + time_stamps = [zf] + time_stamps + [zf] + ages = [zf] + ages + [zf] + visit_orders = [z0] + visit_orders + [z0] + visit_segments = [z0] + visit_segments + [z0] + + ml = self.max_len + concept_ids = _left_pad_int(concept_ids, ml, vocab.pad_id) + token_types = _left_pad_int(token_types, ml, 0) + time_stamps = _left_pad_float(time_stamps, ml, 0.0) + ages = _left_pad_float(ages, ml, 0.0) + visit_orders = _left_pad_int(visit_orders, ml, 0) + visit_segments = _left_pad_int(visit_segments, ml, 0) + + label = infer_mortality_label(patient) + return [ + { + "patient_id": pid, + "visit_id": f"{pid}-0", + "concept_ids": concept_ids, + "token_type_ids": token_types, + "time_stamps": time_stamps, + "ages": ages, + "visit_orders": visit_orders, + "visit_segments": visit_segments, + "label": label, + } + ] diff --git a/pyproject.toml b/pyproject.toml index 934d4f1bb..58afeb1b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "more-itertools~=10.8.0", "einops>=0.8.0", "linear-attention-transformer>=0.19.1", + "orjson~=3.10", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/mimic4_fhir_ndjson_fixtures.py b/tests/core/mimic4_fhir_ndjson_fixtures.py new file mode 100644 index 000000000..2f48e2ba1 --- /dev/null +++ b/tests/core/mimic4_fhir_ndjson_fixtures.py @@ -0,0 +1,30 @@ +"""NDJSON file bodies for :mod:`tests.core.test_mimic4_fhir_dataset` (disk-only ingest).""" + +from 
__future__ import annotations
+
+from pathlib import Path
+
+from pyhealth.datasets.fhir_ingest import (
+    synthetic_mpf_one_patient_ndjson_text,
+    synthetic_mpf_two_patient_ndjson_text,
+)
+
+
+def ndjson_one_patient_text() -> str:
+    return synthetic_mpf_one_patient_ndjson_text()
+
+
+def ndjson_two_class_text() -> str:
+    return synthetic_mpf_two_patient_ndjson_text()
+
+
+def write_two_class_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+    path = directory / name
+    path.write_text(ndjson_two_class_text(), encoding="utf-8")
+    return path
+
+
+def write_one_patient_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+    path = directory / name
+    path.write_text(ndjson_one_patient_text(), encoding="utf-8")
+    return path
diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py
new file mode 100644
index 000000000..6c0ac79e0
--- /dev/null
+++ b/tests/core/test_ehrmamba_cehr.py
@@ -0,0 +1,124 @@
+import unittest
+
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+
+def _tiny_samples(seq: int = 16) -> tuple:
+    from pyhealth.datasets.fhir_cehr import ConceptVocab, ensure_special_tokens
+
+    task = MPFClinicalPredictionTask(max_len=seq, use_mpf=True)
+    task.vocab = ConceptVocab()
+    sp = ensure_special_tokens(task.vocab)
+    mid = task.vocab.add_token("test|filler")
+    samples = []
+    for lab in (0, 1):
+        samples.append(
+            {
+                "patient_id": f"p{lab}",
+                "visit_id": f"v{lab}",
+                "concept_ids": [sp["<mor>"]] + [mid] * (seq - 2) + [sp["<reg>"]],
+                "token_type_ids": [0] * seq,
+                "time_stamps": [0.0] * seq,
+                "ages": [50.0] * seq,
+                "visit_orders": [0] * seq,
+                "visit_segments": [0] * seq,
+                "label": lab,
+            }
+        )
+    return samples, task
+
+
+class TestEHRMambaCEHR(unittest.TestCase):
+    def test_readout_pools_rightmost_non_pad(self) -> None:
+        """MPF padding between tokens must not 
make pooling pick a pad position.""" + + from pyhealth.models.utils import ( + get_last_visit, + get_rightmost_masked_timestep, + ) + + h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]]) + m = torch.tensor([[True, True, False, True]]) + out = get_rightmost_masked_timestep(h, m) + self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0]))) + wrong = get_last_visit(h, m) + self.assertFalse(torch.allclose(out[0], wrong[0])) + + def test_end_to_end_fhir_pipeline(self) -> None: + import tempfile + from pathlib import Path + + from pyhealth.datasets import MIMIC4FHIRDataset, create_sample_dataset + from pyhealth.datasets import get_dataloader + + from tests.core.mimic4_fhir_ndjson_fixtures import write_two_class_ndjson + + task = MPFClinicalPredictionTask(max_len=32, use_mpf=True) + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp + ) + samples = ds.gather_samples(task) + sample_ds = create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name="fhir_test", + ) + vocab_size = max(max(s["concept_ids"]) for s in samples) + 1 + model = EHRMambaCEHR( + dataset=sample_ds, + vocab_size=vocab_size, + embedding_dim=64, + num_layers=1, + ) + batch = next( + iter(get_dataloader(sample_ds, batch_size=2, shuffle=False)) + ) + out = model(**batch) + self.assertIn("loss", out) + out["loss"].backward() + + def test_forward_backward(self) -> None: + samples, task = _tiny_samples() + ds = create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + ) + vocab_size = max(max(s["concept_ids"]) for s in samples) + 1 + model = EHRMambaCEHR( + dataset=ds, + vocab_size=vocab_size, + embedding_dim=64, + num_layers=1, + state_size=8, + ) + batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False))) + out = model(**batch) + 
self.assertEqual(out["logit"].shape[0], 2) + out["loss"].backward() + + def test_eval_mode(self) -> None: + samples, task = _tiny_samples() + ds = create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + ) + vocab_size = max(max(s["concept_ids"]) for s in samples) + 1 + model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1) + model.eval() + with torch.no_grad(): + batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False))) + out = model(**batch) + self.assertIn("y_prob", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_mimic4_fhir_dataset.py b/tests/core/test_mimic4_fhir_dataset.py new file mode 100644 index 000000000..1611b1979 --- /dev/null +++ b/tests/core/test_mimic4_fhir_dataset.py @@ -0,0 +1,700 @@ +import gzip +import tempfile +import unittest +from pathlib import Path +from typing import Dict, List + +import orjson +import polars as pl + +from pyhealth.data import Patient +from pyhealth.datasets import MIMIC4FHIRDataset +from pyhealth.datasets.fhir_ingest import ( + _flatten_resource_to_table_row, + synthetic_mpf_two_patient_ndjson_text, +) +from pyhealth.datasets.fhir_cehr import ( + ConceptVocab, + build_cehr_sequences, + collect_cehr_timeline_events, + infer_mortality_label, +) + +from tests.core.mimic4_fhir_ndjson_fixtures import ( + ndjson_two_class_text, + write_one_patient_ndjson, + write_two_class_ndjson, +) + + +def _third_patient_loinc_resources() -> List[Dict[str, object]]: + return [ + { + "resourceType": "Patient", + "id": "p-synth-3", + "birthDate": "1960-01-01", + }, + { + "resourceType": "Encounter", + "id": "e3", + "subject": {"reference": "Patient/p-synth-3"}, + "period": {"start": "2020-08-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Observation", + "id": "o3", + "subject": {"reference": "Patient/p-synth-3"}, + "encounter": {"reference": "Encounter/e3"}, + "effectiveDateTime": 
"2020-08-01T12:00:00Z", + "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]}, + }, + ] + + +def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path: + lines = synthetic_mpf_two_patient_ndjson_text().strip().split("\n") + lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources()) + path = directory / name + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return path + + +def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient: + return Patient(patient_id=patient_id, data_source=pl.DataFrame(rows)) + + +class TestDeceasedBooleanFlattening(unittest.TestCase): + def test_string_false_not_coerced_by_python_bool(self) -> None: + """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``.""" + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-str-false", + "deceasedBoolean": "false", + } + ) + self.assertIsNotNone(row) + _table, payload = row + self.assertEqual(payload.get("deceased_boolean"), "false") + + def test_string_true_parsed(self) -> None: + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-str-true", + "deceasedBoolean": "true", + } + ) + self.assertIsNotNone(row) + self.assertEqual(row[1].get("deceased_boolean"), "true") + + def test_json_booleans_unchanged(self) -> None: + for raw, expected in ((True, "true"), (False, "false")): + with self.subTest(raw=raw): + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-bool", + "deceasedBoolean": raw, + } + ) + self.assertIsNotNone(row) + self.assertEqual(row[1].get("deceased_boolean"), expected) + + def test_unknown_deceased_type_stored_as_none(self) -> None: + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-garbage", + "deceasedBoolean": {"unexpected": "object"}, + } + ) + self.assertIsNotNone(row) + self.assertIsNone(row[1].get("deceased_boolean")) 
+
+    def test_infer_mortality_respects_string_false_row(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "event_type": "patient",
+                    "timestamp": "2020-01-01T00:00:00",
+                    "patient/deceased_boolean": "false",
+                },
+            ],
+        )
+        self.assertEqual(infer_mortality_label(patient), 0)
+
+
+class TestMIMIC4FHIRDataset(unittest.TestCase):
+    def test_concept_vocab_from_json_empty_token_to_id(self) -> None:
+        v = ConceptVocab.from_json({"token_to_id": {}})
+        self.assertIn("<pad>", v.token_to_id)
+        self.assertIn("<unk>", v.token_to_id)
+        self.assertEqual(v._next_id, 2)
+
+    def test_concept_vocab_from_json_empty_respects_next_id(self) -> None:
+        v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50})
+        self.assertEqual(v._next_id, 50)
+
+    def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None:
+        from pyhealth.datasets.fhir_ingest import sorted_ndjson_files
+
+        with tempfile.TemporaryDirectory() as tmp:
+            root = Path(tmp)
+            (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8")
+            (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8")
+            (root / "notes.txt").write_text("z", encoding="utf-8")
+            wide = sorted_ndjson_files(root, "**/*.ndjson.gz")
+            narrow = sorted_ndjson_files(
+                root,
+                ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"],
+            )
+            self.assertEqual(len(wide), 2)
+            self.assertEqual(len(narrow), 1)
+            self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz")
+
+    def test_dataset_accepts_glob_patterns_kwarg(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_one_patient_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp
+            )
+            self.assertEqual(ds.glob_patterns, ["*.ndjson"])
+            _ = ds.global_event_df.collect(engine="streaming")
+
+    def test_dataset_rejects_both_glob_kwargs(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            with self.assertRaises(ValueError):
+                MIMIC4FHIRDataset(
+                    root=tmp,
+                    glob_pattern="*.ndjson",
+                    
glob_patterns=["*.ndjson"], + cache_dir=tmp, + ) + + def test_disk_fixture_resolves_events_per_patient(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + sub = ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1").collect( + engine="streaming" + ) + self.assertGreaterEqual(len(sub), 2) + self.assertIn("condition/concept_key", sub.columns) + + def test_prepared_flat_tables_exist(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + _ = ds.global_event_df.collect(engine="streaming") + prepared = ds.prepared_tables_dir + self.assertTrue((prepared / "patient.parquet").is_file()) + self.assertTrue((prepared / "encounter.parquet").is_file()) + self.assertTrue((prepared / "condition.parquet").is_file()) + + def test_build_cehr_non_empty(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + task = MPFClinicalPredictionTask(max_len=64, use_mpf=True) + ds.gather_samples(task) + self.assertIsInstance(ds.vocab, ConceptVocab) + self.assertGreater(ds.vocab.vocab_size, 2) + + def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + task_kw = {"max_len": 64, "use_mpf": True} + ds1 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + ds1.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1) + warm_size = ds1.vocab.vocab_size + self.assertGreater(warm_size, 6) + ds2 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + 
ds2.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1) + self.assertEqual(ds2.vocab.vocab_size, warm_size) + + def test_mortality_heuristic(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=64, use_mpf=False)) + self.assertEqual({s["label"] for s in samples}, {0, 1}) + + def test_infer_deceased(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + dead = ds.get_patient("p-synth-2") + self.assertEqual(infer_mortality_label(dead), 1) + + def test_disk_ndjson_gz_physionet_style(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + gz_path = Path(tmp) / "fixture.ndjson.gz" + with gzip.open(gz_path, "wt", encoding="utf-8") as gz: + gz.write(ndjson_two_class_text()) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5) + self.assertGreaterEqual(len(ds.unique_patient_ids), 1) + + def test_disk_ndjson_temp_dir(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", max_patients=5) + self.assertEqual(len(ds.unique_patient_ids), 2) + samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)) + self.assertGreaterEqual(len(samples), 1) + for sample in samples: + self.assertIn("concept_ids", sample) + self.assertIn("label", sample) + + def test_global_event_df_schema_and_flattened_columns(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + 
df = ds.global_event_df.collect(engine="streaming") + self.assertGreater(len(df), 0) + self.assertIn("patient_id", df.columns) + self.assertIn("timestamp", df.columns) + self.assertIn("event_type", df.columns) + self.assertIn("condition/concept_key", df.columns) + self.assertIn("observation/concept_key", df.columns) + self.assertIn("patient/deceased_boolean", df.columns) + + def test_set_task_parity_with_gather_samples_ndjson(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp), name="fx.ndjson") + ds = MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1 + ) + ref = sorted( + ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)), + key=lambda s: s["patient_id"], + ) + sample_ds = ds.set_task( + MPFClinicalPredictionTask(max_len=48, use_mpf=True), num_workers=1 + ) + got = sorted( + [sample_ds[i] for i in range(len(sample_ds))], + key=lambda s: s["patient_id"], + ) + self.assertEqual(len(got), len(ref)) + for expected, actual in zip(ref, got): + self.assertEqual(expected["label"], int(actual["label"])) + actual_ids = actual["concept_ids"] + if hasattr(actual_ids, "tolist"): + actual_ids = actual_ids.tolist() + self.assertEqual(expected["concept_ids"], actual_ids) + + def test_gather_samples_resets_frozen_vocab_after_set_task(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp_a, tempfile.TemporaryDirectory() as tmp_b: + write_two_class_ndjson(Path(tmp_a), name="a.ndjson") + write_two_class_ndjson(Path(tmp_b), name="b.ndjson") + ds_a = MIMIC4FHIRDataset( + root=tmp_a, glob_pattern="*.ndjson", cache_dir=tmp_a, num_workers=1 + ) + ds_b = MIMIC4FHIRDataset( + root=tmp_b, glob_pattern="*.ndjson", cache_dir=tmp_b, num_workers=1 + ) + task = MPFClinicalPredictionTask(max_len=48, use_mpf=True) + ds_a.set_task(task, 
num_workers=1) + self.assertFalse(task.frozen_vocab) + + ref = sorted( + ds_b.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)), + key=lambda s: s["patient_id"], + ) + got = sorted(ds_b.gather_samples(task), key=lambda s: s["patient_id"]) + self.assertEqual(len(got), len(ref)) + for expected, actual in zip(ref, got): + self.assertEqual(expected["label"], actual["label"]) + self.assertEqual(expected["concept_ids"], actual["concept_ids"]) + + def test_set_task_multi_worker_sets_frozen_vocab(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2 + ) + task = MPFClinicalPredictionTask(max_len=48, use_mpf=True) + ds.set_task(task, num_workers=2) + self.assertTrue(task.frozen_vocab) + + def test_mpf_pre_filter_vocab_warmup_excludes_dropped_patients(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + class TwoPatientMPFTask(MPFClinicalPredictionTask): + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + return df.filter(pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"])) + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_plus_third_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1 + ) + self.assertEqual(len(ds.unique_patient_ids), 3) + task = TwoPatientMPFTask(max_len=48, use_mpf=True) + ds.set_task(task, num_workers=1) + self.assertNotIn("http://loinc.org|999-9", ds.vocab.token_to_id) + self.assertIn("http://loinc.org|789-0", ds.vocab.token_to_id) + + def test_mpf_pre_filter_patient_ids_drive_effective_workers(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + class OnePatientMPFTask(MPFClinicalPredictionTask): + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: 
+ return df.filter(pl.col("patient_id") == "p-synth-1") + + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2 + ) + task = OnePatientMPFTask(max_len=48, use_mpf=True) + warmup_pids = ds._mpf_patient_ids_for_task(task) + self.assertEqual(warmup_pids, ["p-synth-1"]) + effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1 + self.assertEqual(effective_workers, 1) + + def test_encounter_reference_requires_exact_id(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-02T10:00:00", + "encounter/encounter_id": "e10", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-07-02T11:00:00", + "condition/encounter_id": "e10", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99", + }, + ], + ) + vocab = ConceptVocab() + concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64) + tid = vocab["http://hl7.org/fhir/sid/icd-10-cm|I99"] + self.assertEqual(concept_ids.count(tid), 1) + + def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "ea", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + 
"encounter/encounter_id": "eb", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-15T12:00:00", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], + ) + vocab = ConceptVocab() + concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64) + self.assertEqual(concept_ids.count(vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]), 1) + + def test_cehr_sequence_shapes(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + patient = ds.get_patient("p-synth-1") + vocab = ConceptVocab() + concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments = ( + build_cehr_sequences(patient, vocab, max_len=32) + ) + n = len(concept_ids) + self.assertEqual(len(token_types), n) + self.assertEqual(len(time_stamps), n) + self.assertEqual(len(ages), n) + self.assertEqual(len(visit_orders), n) + self.assertEqual(len(visit_segments), n) + self.assertGreater(n, 0) + + def test_build_cehr_max_len_zero_no_clinical_tokens(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + ], + ) + vocab = ConceptVocab() + c, _, _, _, _, vs = build_cehr_sequences(patient, vocab, max_len=0) + self.assertEqual(c, []) + self.assertEqual(vs, []) + + def test_visit_segments_alternate_by_visit_index(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": 
"patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e0", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e0", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-07-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20", + }, + ], + ) + vocab = ConceptVocab() + _, _, _, _, _, visit_segments = build_cehr_sequences(patient, vocab, max_len=64) + self.assertEqual(visit_segments, [0, 0, 1, 1]) + + def test_unlinked_visit_idx_matches_sequential_counter(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": None, + "encounter/encounter_id": "e_bad", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-03-01T10:00:00", + "encounter/encounter_id": "e_ok", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-05T11:00:00", + "condition/encounter_id": "e_ok", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-15T12:00:00", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], + ) + vocab = ConceptVocab() + concept_ids, _, _, _, visit_orders, visit_segments 
= build_cehr_sequences( + patient, vocab, max_len=64 + ) + i10 = vocab["http://hl7.org/fhir/sid/icd-10-cm|I10"] + z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"] + i_link = concept_ids.index(i10) + i_free = concept_ids.index(z00) + self.assertEqual(visit_orders[i_link], visit_orders[i_free]) + self.assertEqual(visit_segments[i_link], visit_segments[i_free]) + + def test_medication_request_uses_medication_codeable_concept(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T11:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T12:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222", + }, + ], + ) + vocab = ConceptVocab() + c, *_ = build_cehr_sequences(patient, vocab, max_len=64) + ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111" + kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222" + self.assertNotEqual(vocab[ka], vocab[kb]) + self.assertEqual(c.count(vocab[ka]), 1) + self.assertEqual(c.count(vocab[kb]), 1) + + def test_medication_request_medication_reference_token(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + 
"patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T11:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "MedicationRequest/reference|med-abc", + }, + ], + ) + vocab = ConceptVocab() + c, *_ = build_cehr_sequences(patient, vocab, max_len=64) + key = "MedicationRequest/reference|med-abc" + self.assertIn(vocab[key], c) + self.assertEqual(c.count(vocab[key]), 1) + + def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "a|1", + }, + { + "patient_id": "p1", + "event_type": "observation", + "timestamp": "2020-06-01T12:00:00", + "observation/encounter_id": "e1", + "observation/concept_key": "b|2", + }, + ], + ) + events = collect_cehr_timeline_events(patient) + self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_mpf_task.py b/tests/core/test_mpf_task.py new file mode 100644 index 000000000..d130e2d83 --- /dev/null +++ b/tests/core/test_mpf_task.py @@ -0,0 +1,88 @@ +import shutil +import tempfile +import unittest +from pathlib import Path + +from pyhealth.datasets import MIMIC4FHIRDataset +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + +from tests.core.mimic4_fhir_ndjson_fixtures import write_two_class_ndjson + + +class TestMPFClinicalPredictionTask(unittest.TestCase): + def _two_patient_ds(self) -> MIMIC4FHIRDataset: + tmp = tempfile.mkdtemp() + self.addCleanup(lambda 
p=tmp: shutil.rmtree(p, ignore_errors=True)) + write_two_class_ndjson(Path(tmp)) + return MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp + ) + + def test_max_len_validation(self) -> None: + with self.assertRaises(ValueError): + MPFClinicalPredictionTask(max_len=1, use_mpf=True) + + def test_mpf_sets_boundary_tokens(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=True) + ds = self._two_patient_ds() + samples = ds.gather_samples(task) + vocab = ds.vocab + self.assertGreater(len(samples), 0) + s0 = samples[0] + mor = vocab[""] + reg = vocab[""] + pad_id = vocab.pad_id + ids = s0["concept_ids"] + first = next(i for i, x in enumerate(ids) if x != pad_id) + last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id) + self.assertEqual(ids[first], mor) + self.assertEqual(ids[last_nz], reg) + self.assertEqual(ids[-1], reg) + + def test_no_mpf_uses_cls_reg(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=False) + ds = self._two_patient_ds() + samples = ds.gather_samples(task) + vocab = ds.vocab + s0 = samples[0] + cls_id = vocab[""] + reg_id = vocab[""] + pad_id = vocab.pad_id + ids = s0["concept_ids"] + first = next(i for i, x in enumerate(ids) if x != pad_id) + last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id) + self.assertEqual(ids[first], cls_id) + self.assertEqual(ids[last_nz], reg_id) + self.assertEqual(ids[-1], reg_id) + + def test_schema_keys(self) -> None: + task = MPFClinicalPredictionTask(max_len=16, use_mpf=True) + ds = self._two_patient_ds() + samples = ds.gather_samples(task) + for k in task.input_schema: + self.assertIn(k, samples[0]) + self.assertIn("label", samples[0]) + + def test_max_len_two_keeps_boundary_tokens(self) -> None: + """``clinical_cap=0`` must yield ``[, ]`` left-padded, not truncated.""" + + task = MPFClinicalPredictionTask(max_len=2, use_mpf=True) + ds = self._two_patient_ds() + samples = ds.gather_samples(task) + vocab = ds.vocab + mor 
= vocab[""] + reg = vocab[""] + pad_id = vocab.pad_id + for s in samples: + ids = s["concept_ids"] + first = next(i for i, x in enumerate(ids) if x != pad_id) + last_nz = next( + i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id + ) + self.assertEqual(ids[first], mor) + self.assertEqual(ids[last_nz], reg) + self.assertEqual(ids[-1], reg) + + +if __name__ == "__main__": + unittest.main()