diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b02439d26..ea772bfec 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,7 @@ Available Datasets
datasets/pyhealth.datasets.SampleDataset
datasets/pyhealth.datasets.MIMIC3Dataset
datasets/pyhealth.datasets.MIMIC4Dataset
+ datasets/pyhealth.datasets.MIMIC4FHIRDataset
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.CardiologyDataset
datasets/pyhealth.datasets.eICUDataset
diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
new file mode 100644
index 000000000..1a19fc5e9
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
@@ -0,0 +1,70 @@
+pyhealth.datasets.MIMIC4FHIRDataset
+=====================================
+
+`MIMIC-IV on FHIR `_ NDJSON ingest
+for CEHR-style token sequences used with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.models.EHRMambaCEHR`.
+
+YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the
+earlier nested-object approach, the YAML now declares a normal ``tables:``
+schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``,
+``observation``, ``medication_request``, ``procedure``). The class subclasses
+:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars
+``global_event_df`` backed by cached Parquet (``global_event_df.parquet/part-*.parquet``),
+same tabular path as other datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`,
+:meth:`iter_patients`, :meth:`get_patient`, etc.
+
+**Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read
+**line by line**; each resource is normalized into a flattened per-resource
+Parquet table under ``cache/flattened_tables/``. Those tables are then fed
+through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to
+materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual
+table-first pipeline instead of reparsing nested JSON per patient downstream.
+
+**``max_patients``.** When set, the loader selects the first *N* patient ids after
+a **sorted** ``unique`` over the flattened patient table, filters every
+normalized table to that cohort, and then builds ``global_event_df`` from the
+filtered tables. Ingest still scans all matching NDJSON once unless you also
+override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards).
+
+**Downstream memory (still important).** Streaming ingest avoids loading the
+entire NDJSON corpus into RAM at once, but other steps can still be heavy on
+large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and
+:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the
+model and batches. For a **full** PhysioNet tree, plan for **large disk**
+(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow
+and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when
+prototyping on a laptop.
+
+**Recommended hardware (informal)**
+
+Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller**
+than “load everything into Python”; wall time still grows with **decompressed
+NDJSON volume** and the amount of flattened table data produced.
+
+* **Smoke / CI**
+ Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``):
+ a recent laptop is sufficient.
+
+* **Laptop-scale real FHIR subset**
+ A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps
+ cache and task passes manageable. **≥ 16 GB** system RAM is a practical
+ comfort target for Polars + trainer + OS; validate GPU **VRAM** for your
+ ``max_len`` and batch size.
+
+* **Full default globs on a complete export**
+ Favor **workstations or servers** with **fast SSD**, **large disk**, and
+ **ample RAM** for downstream steps—not because NDJSON is fully buffered in
+ memory during ingest, but because total work and caches still scale with the
+ full dataset.
+
+.. autoclass:: pyhealth.datasets.MIMIC4FHIRDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: pyhealth.datasets.ConceptVocab
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..c7e9f2729 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -185,6 +185,7 @@ API Reference
models/pyhealth.models.MoleRec
models/pyhealth.models.Deepr
models/pyhealth.models.EHRMamba
+ models/pyhealth.models.EHRMambaCEHR
models/pyhealth.models.JambaEHR
models/pyhealth.models.ContraWR
models/pyhealth.models.SparcNet
diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
new file mode 100644
index 000000000..79466cad3
--- /dev/null
+++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
@@ -0,0 +1,12 @@
+pyhealth.models.EHRMambaCEHR
+===================================
+
+EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`)
+and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.datasets.mimic4_fhir.MIMIC4FHIRDataset`.
+
+.. autoclass:: pyhealth.models.EHRMambaCEHR
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 399b8f1aa..83790ca44 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -214,6 +214,7 @@ Available Tasks
Drug Recommendation
Length of Stay Prediction
Medical Transcriptions Classification
+ MPF Clinical Prediction (FHIR)
Mortality Prediction (Next Visit)
Mortality Prediction (StageNet MIMIC-IV)
Patient Linkage (MIMIC-III)
diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
new file mode 100644
index 000000000..ff66deb08
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
@@ -0,0 +1,12 @@
+pyhealth.tasks.mpf_clinical_prediction
+======================================
+
+Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR
+token timelines, paired with :class:`~pyhealth.datasets.MIMIC4FHIRDataset` and
+:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the
+paper linked in the course replication PR.
+
+.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py
new file mode 100644
index 000000000..a80d9c43c
--- /dev/null
+++ b/examples/mimic4fhir_mpf_ehrmamba.py
@@ -0,0 +1,842 @@
+"""EHRMambaCEHR on MIMIC-IV FHIR NDJSON with MPF clinical prediction (ablations).
+
+Replication target: EHRMamba / CEHR-style modeling on tokenized FHIR timelines
+(e.g. `arXiv:2405.14567 `_). This script is
+runnable end-to-end on **synthetic** NDJSON (``--quick-test``) or on
+credentialled MIMIC-IV on FHIR from PhysioNet.
+
+Experimental setup (for write-ups / PR):
+ * **Data**: Synthetic two-patient NDJSON (``--quick-test``) or disk NDJSON
+ under ``MIMIC4_FHIR_ROOT`` / ``--fhir-root``.
+ * **Task ablations**: ``max_len`` (context window), ``use_mpf`` vs generic
+ ````/```` boundaries (``--no-mpf``).
+ * **Model ablations**: ``hidden_dim`` (embedding width); optional dropout
+ fixed at 0.1 in this script.
+ * **Train**: Adam via :class:`~pyhealth.trainer.Trainer`, monitor ROC-AUC,
+ report test ROC-AUC / PR-AUC.
+
+**Ablation mode** (``--ablation``): sweeps a small grid on synthetic data only,
+trains 1 epoch per config, and prints a comparison table. Use this to document
+how task/model knobs affect metrics on the minimal fixture before scaling to
+real FHIR.
+
+**Findings** (fill in after your runs; synthetic runs are noisy):
+ On ``--quick-test`` data, longer ``max_len`` and MPF specials typically
+ change logits enough to move AUC slightly; real MIMIC-IV FHIR runs are
+ needed for conclusive comparisons. Paste your table from ``--ablation``
+ into the PR description.
+
+**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into
+flattened per-resource Parquet tables (bounded RAM during ingest). This example trains via
+``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed
+:class:`~pyhealth.datasets.sample_dataset.SampleDataset` →
+:class:`~pyhealth.trainer.Trainer` (PyHealth’s standard path), instead of
+materializing all samples with ``gather_samples()``. Prefer ``--max-patients`` to
+ bound ingest when possible. Very large cohorts still need RAM/disk for task
+caches and MPF vocabulary warmup.
+
+**Offline flattened tables (NDJSON normalization already done):** pass
+``--prebuilt-global-event-dir`` pointing at a directory containing the normalized
+FHIR tables (``patient.parquet``, ``encounter.parquet``, ``condition.parquet``,
+etc.). The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID,
+then lets :class:`~pyhealth.datasets.BaseDataset` rebuild
+``global_event_df.parquet/`` from those tables — the downstream path is still
+``global_event_df`` → :class:`~pyhealth.data.Patient` →
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` →
+:class:`~pyhealth.trainer.Trainer``. Use ``--fhir-root`` / ``--glob-pattern`` /
+``--max-patients -1`` matching the ingest fingerprint.
+``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using a
+label-aware deterministic patient subset. The full ``unique_patient_ids`` scan and MPF vocab warmup
+in the dataset still walk the cached cohort.
+
+**Approximate minimum specs** (``--quick-test``, CPU, synthetic 2-patient
+fixture; measured once on macOS/arm64 with ``/usr/bin/time -l``): peak RSS
+~**600–700 MiB**, wall **~10–15 s** for two short epochs. Real NDJSON/GZ at scale
+needs proportionally more RAM, disk, and time; GPU helps training, not the
+current all-in-RAM parse.
+
+Usage:
+ cd PyHealth && PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test
+ PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test --ablation
+ export MIMIC4_FHIR_ROOT=/path/to/fhir
+ pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT"
+
+ # Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train
+ pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\
+ --prebuilt-global-event-dir /path/to/flattened_table_dir \\
+ --fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\
+ --train-patient-cap 2048 --epochs 2 \\
+ --ntfy-url 'https://ntfy.sh/your-topic'
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import urllib.error
+import urllib.request
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+_parser = argparse.ArgumentParser(description="EHRMambaCEHR on MIMIC-IV FHIR (MPF)")
+_parser.add_argument(
+ "--gpu",
+ type=int,
+ default=None,
+ help="GPU index; sets CUDA_VISIBLE_DEVICES.",
+)
+_parser.add_argument(
+ "--fhir-root",
+ type=str,
+ default=None,
+ help="Root directory with NDJSON (default: MIMIC4_FHIR_ROOT env).",
+)
+_parser.add_argument(
+ "--glob-pattern",
+ type=str,
+ default=None,
+ help=(
+ "Override glob for NDJSON/NDJSON.GZ (default: yaml **/*.ndjson.gz). "
+ "Use a narrow pattern to limit ingest time and cache size."
+ ),
+)
+_parser.add_argument(
+ "--max-len",
+ type=int,
+ default=512,
+ help="Sequence length ablation (e.g. 512 / 1024 / 2048 per proposal).",
+)
+_parser.add_argument(
+ "--no-mpf",
+ action="store_true",
+ help="Ablation: use generic CLS/REG specials instead of task MPF tokens.",
+)
+_parser.add_argument(
+ "--hidden-dim",
+ type=int,
+ default=128,
+ help="Embedding / hidden size (model ablation).",
+)
+_parser.add_argument(
+ "--lr",
+ type=float,
+ default=1e-3,
+ help="Adam learning rate (trainer.train optimizer_params).",
+)
+_parser.add_argument(
+ "--quick-test",
+ action="store_true",
+ help="Use synthetic in-memory FHIR lines only (no disk root).",
+)
+_parser.add_argument(
+ "--ablation",
+ action="store_true",
+ help="Run a small max_len × MPF × hidden_dim grid on synthetic data; print table.",
+)
+_parser.add_argument(
+ "--epochs",
+ type=int,
+ default=None,
+ help="Training epochs (default: 2 with --quick-test, else 20).",
+)
+_parser.add_argument(
+ "--max-patients",
+ type=int,
+ default=500,
+ help=(
+ "Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, "
+ "match an uncapped NDJSON→flattened-table export)."
+ ),
+)
+_parser.add_argument(
+ "--prebuilt-global-event-dir",
+ type=str,
+ default=None,
+ help=(
+ "Directory with normalized flattened FHIR tables (*.parquet). Seeds "
+ "cache/flattened_tables/ so training skips NDJSON normalization "
+ "(downstream unchanged: Patient + MPF + Trainer)."
+ ),
+)
+_parser.add_argument(
+ "--ingest-num-shards",
+ type=int,
+ default=None,
+ help="Compatibility no-op: retained for CLI stability with older runs.",
+)
+_parser.add_argument(
+ "--train-patient-cap",
+ type=int,
+ default=None,
+ help=(
+ "After cache is ready, only build samples from a deterministic label-aware "
+ "patient subset of size N (reduces train time; unique-id scan of "
+ "global_event_df still runs once)."
+ ),
+)
+_parser.add_argument(
+ "--ntfy-url",
+ type=str,
+ default=None,
+ help="POST notification when main() finishes (e.g. https://ntfy.sh/topic).",
+)
+_parser.add_argument(
+ "--loss-plot-path",
+ type=str,
+ default=None,
+ help="Write loss curve PNG here (default: alongside Trainer log under output/).",
+)
+_parser.add_argument(
+ "--cache-dir",
+ type=str,
+ default=None,
+ help="PyHealth dataset cache parent (UUID subdir added by MIMIC4FHIRDataset).",
+)
+_parser.add_argument(
+ "--task-num-workers",
+ type=int,
+ default=None,
+ help=(
+ "Workers for LitData task/processor transforms (default: dataset "
+ "``num_workers``, usually 1)."
+ ),
+)
+_pre_args, _ = _parser.parse_known_args()
+if _pre_args.gpu is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(_pre_args.gpu)
+
+import torch
+
+import polars as pl
+
+from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader
+from pyhealth.datasets.fhir_cehr import infer_mortality_label
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+from pyhealth.trainer import Trainer
+
+
+class PatientCappedMPFTask(MPFClinicalPredictionTask):
+ """Example-only: limit task transform to an explicit patient_id allow-list."""
+
+ def __init__(
+ self,
+ *,
+ max_len: int,
+ use_mpf: bool,
+ patient_ids_allow: List[str],
+ ) -> None:
+ super().__init__(max_len=max_len, use_mpf=use_mpf)
+ self.patient_ids_allow = patient_ids_allow
+
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(pl.col("patient_id").is_in(self.patient_ids_allow))
+
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+SEED = 42
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+BATCH_SIZE = 8
+EPOCHS = 20
+SPLIT_RATIOS = (0.7, 0.1, 0.2)
+
+
+def _max_patients_arg(v: int) -> Optional[int]:
+ return None if v is not None and v < 0 else v
+
+
+def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None:
+ """Copy normalized per-resource parquet tables into the dataset cache."""
+
+ tables = sorted(prebuilt_dir.glob("*.parquet"))
+ if not tables:
+ raise FileNotFoundError(
+ f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables."
+ )
+ prepared = ds.prepared_tables_dir
+ if prepared.exists() and any(prepared.glob("*.parquet")):
+ return
+ prepared.mkdir(parents=True, exist_ok=True)
+ for src in tables:
+ dest = prepared / src.name
+ if dest.exists():
+ continue
+ try:
+ os.link(src, dest)
+ except OSError:
+ shutil.copy2(src, dest)
+
+
+def _parse_train_losses_from_log(log_path: Path) -> List[float]:
+ """Mean training loss per epoch from Trainer file log."""
+
+ if not log_path.is_file():
+ return []
+ text = log_path.read_text(encoding="utf-8", errors="replace")
+ losses: List[float] = []
+ lines = text.splitlines()
+ for i, line in enumerate(lines):
+ if "--- Train epoch-" in line and i + 1 < len(lines):
+ m = re.search(r"loss:\s*([0-9.eE+-]+)", lines[i + 1])
+ if m:
+ losses.append(float(m.group(1)))
+ return losses
+
+
+def _write_loss_plot(losses: List[float], out_path: Path) -> None:
+ if not losses:
+ return
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError:
+ csv_path = out_path.with_suffix(".csv")
+ csv_path.write_text(
+ "epoch,train_loss_mean\n"
+ + "\n".join(f"{i},{v}" for i, v in enumerate(losses)),
+ encoding="utf-8",
+ )
+ print(
+ "matplotlib not installed; wrote", csv_path, "(pip install matplotlib for PNG)"
+ )
+ return
+ plt.figure(figsize=(6, 3.5))
+ plt.plot(range(len(losses)), losses, marker="o", linewidth=1)
+ plt.xlabel("epoch")
+ plt.ylabel("mean train loss")
+ plt.title("EHRMambaCEHR training loss (MPF)")
+ plt.grid(True, alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(out_path, dpi=120)
+ plt.close()
+ print("loss plot:", out_path)
+
+
+def _ntfy(url: str, title: str, message: str) -> None:
+ try:
+ req = urllib.request.Request(
+ url,
+ data=message.encode("utf-8"),
+ method="POST",
+ )
+ req.add_header("Title", title[:200])
+ with urllib.request.urlopen(req, timeout=60) as resp:
+ if resp.status >= 400:
+ print("ntfy HTTP", resp.status, file=sys.stderr)
+ except urllib.error.URLError as e:
+ print("ntfy failed:", e, file=sys.stderr)
+
+
+def _quick_test_ndjson_dir() -> str:
+ """Write two-patient synthetic NDJSON; returns temp directory (caller cleans up)."""
+
+ from pyhealth.datasets.fhir_ingest import synthetic_mpf_two_patient_ndjson_text
+
+ tmp = tempfile.mkdtemp(prefix="pyhealth_mimic4_fhir_quick_")
+ Path(tmp, "fixture.ndjson").write_text(
+ synthetic_mpf_two_patient_ndjson_text(),
+ encoding="utf-8",
+ )
+ return tmp
+
+
+def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int:
+ patient = ds.get_patient(patient_id)
+ return int(infer_mortality_label(patient))
+
+
+def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None:
+ found: Dict[int, str] = {}
+ scanned = 0
+ for patient_id in ds.unique_patient_ids:
+ label = _patient_label(ds, patient_id)
+ scanned += 1
+ found.setdefault(label, patient_id)
+ if len(found) == 2:
+ print(
+ "label preflight:",
+ {"scanned_patients": scanned, "example_patient_ids": found},
+ )
+ return
+ raise SystemExit(
+ "Binary mortality example found only one label in the available cohort; "
+ "cannot build a valid binary training set from this cache."
+ )
+
+
+def _select_patient_ids_for_cap(
+ ds: MIMIC4FHIRDataset, requested_cap: int
+) -> List[str]:
+ patient_ids = ds.unique_patient_ids
+ if not patient_ids:
+ return []
+
+ desired = max(2, requested_cap)
+ desired = min(desired, len(patient_ids))
+ if desired < requested_cap:
+ print(
+ f"train_patient_cap requested {requested_cap}, but only {desired} patients are available."
+ )
+ elif requested_cap < 2:
+ print(
+ f"train_patient_cap={requested_cap} is too small for binary labels; using {desired}."
+ )
+
+ encountered: List[str] = []
+ label_by_patient_id: Dict[str, int] = {}
+ first_by_label: Dict[int, str] = {}
+ for patient_id in patient_ids:
+ label = _patient_label(ds, patient_id)
+ encountered.append(patient_id)
+ label_by_patient_id[patient_id] = label
+ first_by_label.setdefault(label, patient_id)
+ if len(encountered) >= desired and len(first_by_label) == 2:
+ break
+
+ if len(first_by_label) < 2:
+ raise SystemExit(
+ "Unable to satisfy --train-patient-cap with both binary labels from the "
+ "available cohort. Use a different cache/export or remove the cap."
+ )
+
+ selected = encountered[:desired]
+ selected_labels = {label_by_patient_id[pid] for pid in selected}
+ if len(selected_labels) == 1:
+ missing_label = 1 - next(iter(selected_labels))
+ replacement = first_by_label[missing_label]
+ for idx in range(len(selected) - 1, -1, -1):
+ if label_by_patient_id[selected[idx]] != missing_label:
+ selected[idx] = replacement
+ break
+
+ counts = {
+ 0: sum(1 for pid in selected if label_by_patient_id[pid] == 0),
+ 1: sum(1 for pid in selected if label_by_patient_id[pid] == 1),
+ }
+ print(
+ "train_patient_cap selection:",
+ {
+ "requested": requested_cap,
+ "selected": len(selected),
+ "scanned_patients": len(encountered),
+ "label_counts": counts,
+ },
+ )
+ return selected
+
+
+def _sample_label(sample: Dict[str, Any]) -> int:
+ label = sample["label"]
+ if isinstance(label, torch.Tensor):
+ return int(label.reshape(-1)[0].item())
+ return int(label)
+
+
+def _split_counts(n: int) -> List[int]:
+ if n < 3:
+ raise ValueError("Need at least 3 samples for three-way stratified split.")
+ counts = [1, 1, 1]
+ remaining = n - 3
+ raw = [ratio * remaining for ratio in SPLIT_RATIOS]
+ floors = [int(x) for x in raw]
+ for i, floor in enumerate(floors):
+ counts[i] += floor
+ assigned = sum(counts)
+ order = sorted(
+ range(3),
+ key=lambda i: raw[i] - floors[i],
+ reverse=True,
+ )
+ for i in order:
+ if assigned >= n:
+ break
+ counts[i] += 1
+ assigned += 1
+ counts[0] += n - assigned
+ return counts
+
+
+def _split_sample_dataset_for_binary_metrics(sample_ds: Any) -> tuple[Any, Any, Any]:
+ if len(sample_ds) < 8:
+ print("sample count < 8; reusing the full dataset for train/val/test.")
+ return sample_ds, sample_ds, sample_ds
+
+ label_to_indices: Dict[int, List[int]] = {0: [], 1: []}
+ for idx in range(len(sample_ds)):
+ label_to_indices[_sample_label(sample_ds[idx])].append(idx)
+
+ label_counts = {label: len(indices) for label, indices in label_to_indices.items()}
+ min_count = min(label_counts.values())
+ if min_count < 3:
+ print(
+ "label distribution too small for disjoint binary train/val/test splits; "
+ "reusing the full dataset for train/val/test.",
+ label_counts,
+ )
+ return sample_ds, sample_ds, sample_ds
+
+ rng = random.Random(SEED)
+ split_indices: List[List[int]] = [[], [], []]
+ for indices in label_to_indices.values():
+ shuffled = indices[:]
+ rng.shuffle(shuffled)
+ n_train, n_val, n_test = _split_counts(len(shuffled))
+ split_indices[0].extend(shuffled[:n_train])
+ split_indices[1].extend(shuffled[n_train : n_train + n_val])
+ split_indices[2].extend(shuffled[n_train + n_val : n_train + n_val + n_test])
+
+ for indices in split_indices:
+ indices.sort()
+
+ split_counts = []
+ for indices in split_indices:
+ split_counts.append(
+ {
+ 0: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 0),
+ 1: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 1),
+ "n": len(indices),
+ }
+ )
+ print(
+ "binary stratified split counts:",
+ {"train": split_counts[0], "val": split_counts[1], "test": split_counts[2]},
+ )
+ return (
+ sample_ds.subset(split_indices[0]),
+ sample_ds.subset(split_indices[1]),
+ sample_ds.subset(split_indices[2]),
+ )
+
+
+def _build_loaders_from_sample_dataset(
+ sample_ds: Any,
+ vocab_size: int,
+) -> tuple[Any, Any, Any, Any, int]:
+ train_ds, val_ds, test_ds = _split_sample_dataset_for_binary_metrics(sample_ds)
+ train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
+ val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
+ test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
+ return sample_ds, train_loader, val_loader, test_loader, vocab_size
+
+
+def run_single_train(
+ *,
+ fhir_root: str,
+ max_len: int,
+ use_mpf: bool,
+ hidden_dim: int,
+ epochs: int,
+ lr: float = 1e-3,
+ glob_pattern: str = "*.ndjson",
+ cache_dir: Optional[str] = None,
+ dataset_max_patients: Optional[int] = 500,
+ ingest_num_shards: Optional[int] = None,
+ prebuilt_global_event_dir: Optional[str] = None,
+ train_patient_cap: Optional[int] = None,
+) -> Dict[str, float]:
+ """Train/eval one configuration; returns test metrics (floats)."""
+
+ ds_kw: Dict[str, Any] = {
+ "root": fhir_root,
+ "glob_pattern": glob_pattern,
+ "cache_dir": cache_dir,
+ "max_patients": dataset_max_patients,
+ }
+ if ingest_num_shards is not None:
+ ds_kw["ingest_num_shards"] = ingest_num_shards
+ ds = MIMIC4FHIRDataset(**ds_kw)
+ if prebuilt_global_event_dir:
+ _seed_flattened_table_cache(
+ Path(prebuilt_global_event_dir).expanduser().resolve(), ds
+ )
+ if train_patient_cap is not None:
+ allow = _select_patient_ids_for_cap(ds, train_patient_cap)
+ task: MPFClinicalPredictionTask = PatientCappedMPFTask(
+ max_len=max_len,
+ use_mpf=use_mpf,
+ patient_ids_allow=allow,
+ )
+ else:
+ _ensure_binary_label_coverage(ds)
+ task = MPFClinicalPredictionTask(max_len=max_len, use_mpf=use_mpf)
+ sample_ds = ds.set_task(task, num_workers=1)
+ vocab_size = ds.vocab.vocab_size
+ sample_ds, train_l, val_l, test_l, vocab_size = _build_loaders_from_sample_dataset(
+ sample_ds, vocab_size
+ )
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=hidden_dim,
+ num_layers=2,
+ dropout=0.1,
+ )
+ trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE)
+ trainer.train(
+ train_dataloader=train_l,
+ val_dataloader=val_l,
+ epochs=epochs,
+ monitor="roc_auc",
+ optimizer_params={"lr": lr},
+ )
+ results = trainer.evaluate(test_l)
+ return {k: float(v) for k, v in results.items()}
+
+
+def run_ablation_table(*, lr: float = 1e-3) -> None:
+ """Task × model grid on synthetic NDJSON (short runs for comparison)."""
+
+ # Ablations: context length, MPF vs CLS/REG, plus one hidden_dim pair.
+ grid = [
+ (32, True, 32),
+ (32, False, 32),
+ (96, True, 64),
+ (96, False, 64),
+ ]
+ tmp = _quick_test_ndjson_dir()
+ try:
+ print(
+ "Ablation (synthetic, 1 epoch each): max_len, use_mpf, hidden_dim, lr="
+ f"{lr} -> test roc_auc, pr_auc"
+ )
+ rows = []
+ t0 = time.perf_counter()
+ for max_len, use_mpf, hidden_dim in grid:
+ metrics = run_single_train(
+ fhir_root=tmp,
+ max_len=max_len,
+ use_mpf=use_mpf,
+ hidden_dim=hidden_dim,
+ epochs=1,
+ lr=lr,
+ cache_dir=tmp,
+ dataset_max_patients=500,
+ )
+ rows.append((max_len, use_mpf, hidden_dim, metrics))
+ print(
+ f" max_len={max_len} mpf={use_mpf} hid={hidden_dim} -> "
+ f"roc_auc={metrics['roc_auc']:.4f} pr_auc={metrics['pr_auc']:.4f}"
+ )
+ print("ablation_wall_s:", round(time.perf_counter() - t0, 2))
+ best = max(rows, key=lambda r: r[3]["roc_auc"])
+ print(
+ "best_by_roc_auc:",
+ {
+ "max_len": best[0],
+ "use_mpf": best[1],
+ "hidden_dim": best[2],
+ "metrics": best[3],
+ },
+ )
+ except Exception:
+ print(f"ablation: leaving scratch directory for debugging: {tmp}", file=sys.stderr)
+ raise
+ else:
+ shutil.rmtree(tmp, ignore_errors=True)
+
+
+def main() -> None:
+ args = _parser.parse_args()
+ status = "abort"
+ ntfy_detail = ""
+ try:
+ _main_train(args)
+ status = "ok"
+ ntfy_detail = "Training finished successfully."
+ except SystemExit as e:
+ status = "exit"
+ ntfy_detail = f"SystemExit {e.code!r}"
+ raise
+ except Exception as e:
+ status = "error"
+ ntfy_detail = f"{type(e).__name__}: {e}"[:3800]
+ raise
+ finally:
+ if args.ntfy_url and status in ("ok", "error"):
+ _ntfy(
+ args.ntfy_url,
+ "mimic-fhir-train OK" if status == "ok" else "mimic-fhir-train FAIL",
+ ntfy_detail,
+ )
+
+
+def _main_train(args: argparse.Namespace) -> None:
+ fhir_root = args.fhir_root or os.environ.get("MIMIC4_FHIR_ROOT")
+ quick = args.quick_test
+ quick_test_tmp: Optional[str] = None
+ if args.epochs is not None:
+ epochs = args.epochs
+ else:
+ epochs = 2 if quick else EPOCHS
+
+ if args.ablation:
+ if not quick:
+ raise SystemExit("--ablation requires --quick-test (synthetic data only).")
+ run_ablation_table(lr=args.lr)
+ return
+
+ print("EHRMambaCEHR – MIMIC-IV FHIR (MPF clinical prediction)")
+ print("device:", DEVICE)
+ print("max_len:", args.max_len, "| use_mpf:", not args.no_mpf)
+ print("hidden_dim:", args.hidden_dim, "| lr:", args.lr)
+
+ sample_ds: Any
+ vocab: Any
+
+ if quick:
+ quick_test_tmp = _quick_test_ndjson_dir()
+ ds = MIMIC4FHIRDataset(
+ root=quick_test_tmp,
+ glob_pattern="*.ndjson",
+ cache_dir=quick_test_tmp,
+ max_patients=500,
+ )
+ try:
+ print(
+ "pipeline: synthetic NDJSON → flattened tables → global_event_df "
+ "→ set_task → SampleDataset → Trainer"
+ )
+ task = MPFClinicalPredictionTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ )
+ print("set_task (quick-test, num_workers=1)...")
+ t_task0 = time.perf_counter()
+ sample_ds = ds.set_task(task, num_workers=1)
+ print(
+ "set_task done: n_samples=",
+ len(sample_ds),
+ "wall_s=",
+ round(time.perf_counter() - t_task0, 2),
+ )
+ vocab = ds.vocab
+ except Exception:
+ print(
+ f"quick-test: leaving NDJSON/Parquet scratch at {quick_test_tmp}",
+ file=sys.stderr,
+ )
+ raise
+ else:
+ mp = _max_patients_arg(args.max_patients)
+ if not fhir_root or not os.path.isdir(fhir_root):
+ raise SystemExit(
+ "Set MIMIC4_FHIR_ROOT or pass --fhir-root to an existing directory "
+ "(NDJSON tree for ingest fingerprint, even when using --prebuilt-global-event-dir)."
+ )
+ ds_kw: Dict[str, Any] = {
+ "root": fhir_root,
+ "max_patients": mp,
+ "cache_dir": args.cache_dir,
+ }
+ if args.glob_pattern is not None:
+ ds_kw["glob_pattern"] = args.glob_pattern
+ if args.ingest_num_shards is not None:
+ ds_kw["ingest_num_shards"] = args.ingest_num_shards
+ ds = MIMIC4FHIRDataset(**ds_kw)
+ if args.prebuilt_global_event_dir:
+ pb = Path(args.prebuilt_global_event_dir).expanduser().resolve()
+ if not pb.is_dir():
+ raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}")
+ print(
+ "pipeline: offline flattened FHIR tables → seed flattened table cache "
+ "→ global_event_df → set_task → SampleDataset → Trainer "
+ "(no NDJSON normalization)"
+ )
+ _seed_flattened_table_cache(pb, ds)
+ else:
+ print(
+ "pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df "
+ "→ set_task → SampleDataset → Trainer"
+ )
+ print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp)
+ if args.train_patient_cap is not None:
+ print("train_patient_cap:", args.train_patient_cap)
+ allow = _select_patient_ids_for_cap(ds, args.train_patient_cap)
+ mpf_task: MPFClinicalPredictionTask = PatientCappedMPFTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ patient_ids_allow=allow,
+ )
+ print("task patient allow-list size:", len(allow))
+ else:
+ _ensure_binary_label_coverage(ds)
+ mpf_task = MPFClinicalPredictionTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ )
+ nw = args.task_num_workers
+ if nw is None:
+ nw = ds.num_workers
+ print(f"set_task (LitData task cache, num_workers={nw})...")
+ t_task0 = time.perf_counter()
+ sample_ds = ds.set_task(mpf_task, num_workers=nw)
+ print(
+ "set_task done: n_samples=",
+ len(sample_ds),
+ "wall_s=",
+ round(time.perf_counter() - t_task0, 2),
+ )
+ vocab = ds.vocab
+ print("fhir_root:", fhir_root)
+
+ try:
+ if len(sample_ds) == 0:
+ raise SystemExit(
+ "No training samples (0 patients or empty sequences). "
+ "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in "
+ "pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, "
+ "construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'."
+ )
+
+ sample_ds, train_loader, val_loader, test_loader, vocab_size = (
+ _build_loaders_from_sample_dataset(sample_ds, vocab.vocab_size)
+ )
+
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=args.hidden_dim,
+ num_layers=2,
+ dropout=0.1,
+ )
+ trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE)
+
+ t0 = time.perf_counter()
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ epochs=epochs,
+ monitor="roc_auc",
+ optimizer_params={"lr": args.lr},
+ )
+ results = trainer.evaluate(test_loader)
+ print("Test:", {k: float(v) for k, v in results.items()})
+ print("wall_s:", round(time.perf_counter() - t0, 1))
+ print("concept_vocab_size:", vocab.vocab_size)
+
+ log_txt = (
+ Path(trainer.exp_path) / "log.txt" if trainer.exp_path else None
+ )
+ if log_txt and log_txt.is_file():
+ losses = _parse_train_losses_from_log(log_txt)
+ print("train_loss_per_epoch:", losses)
+ plot_path = (
+ Path(args.loss_plot_path)
+ if args.loss_plot_path
+ else Path(trainer.exp_path) / "train_loss.png"
+ )
+ if trainer.exp_path:
+ _write_loss_plot(losses, plot_path)
+ finally:
+ if quick_test_tmp is not None:
+ shutil.rmtree(quick_test_tmp, ignore_errors=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pixi.lock b/pixi.lock
index 0f11d28d7..26c1b2237 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -96,6 +96,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -217,6 +218,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -328,6 +330,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -440,6 +443,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -597,6 +601,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -727,6 +732,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -848,6 +854,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -970,6 +977,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1158,6 +1166,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1324,6 +1333,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1474,6 +1484,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1626,6 +1637,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1777,6 +1789,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1906,6 +1919,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2026,6 +2040,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2147,6 +2162,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2224,6 +2240,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda
+ - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl
@@ -2240,6 +2257,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz
- pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/75/b4/b96bb66f6f8cc4669de44a158099b249c8159231d254ab6b092909388be5/fonttools-4.59.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl
@@ -2269,6 +2287,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
@@ -2286,6 +2305,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -2308,6 +2328,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/34/43/3f250ec28edff1c06ffaa25faddbe13ae85c11a9724894cbdcf89427de78/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
- pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
@@ -2360,6 +2381,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/readline-8.2-h8382b9d_2.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/tk-8.6.13-noxft_h5688188_102.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda
+ - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl
@@ -2376,6 +2398,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz
- pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b5/57/7969af50b26408be12baa317c6147588db5b38af2759e6df94554dbc5fdb/fonttools-4.59.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl
@@ -2405,9 +2428,11 @@ environments:
- pypi: https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -2430,6 +2455,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ff/5f/907a48c5f9b83302b4530605df1325963977fdf06753d3d8610d16c40197/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
- pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
@@ -2472,6 +2498,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.2-h1d1bf99_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_2.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda
+ - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl
@@ -2488,6 +2515,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz
- pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f3/bb/390990e7c457d377b00890d9f96a3ca13ae2517efafb6609c1756e213ba4/fonttools-4.59.0-cp313-cp313-macosx_10_13_universal2.whl
@@ -2517,9 +2545,11 @@ environments:
- pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -2542,6 +2572,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/3b/0b/6ab0cc692b2890f4f7c74f6ffd4bba748dcb9312d5a7bd2328cb82204da1/rdkit-2025.3.3-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
- pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl
@@ -2585,6 +2616,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
- conda: https://conda.anaconda.org/conda-forge/win-64/vc-14.3-h41ae7f8_26.conda
- conda: https://conda.anaconda.org/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda
+ - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl
@@ -2602,6 +2634,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz
- pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a0/ee/f626cd372932d828508137a79b85167fdcf3adab2e3bed433f295c596c6a/fonttools-4.59.0-cp313-cp313-win_amd64.whl
@@ -2630,9 +2663,11 @@ environments:
- pypi: https://files.pythonhosted.org/packages/74/07/1ed8277f8653c40ebc65985180b007879f6a836c525b3885dcc6448ae6cb/msgpack-1.1.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -2655,6 +2690,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/98/da/164e31b607c0cf22f1179cd15fa058780f940b21ec42ba3c9026c21897e3/rdkit-2025.3.3-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
- pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl
@@ -2776,6 +2812,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -2897,6 +2934,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -3008,6 +3046,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -3120,6 +3159,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -3213,6 +3253,11 @@ packages:
purls: []
size: 8191
timestamp: 1744137672556
+- pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl
+ name: absl-py
+ version: 2.4.0
+ sha256: 88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d
+ requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl
name: accelerate
version: 1.10.0
@@ -3958,6 +4003,11 @@ packages:
- pkg:pypi/editables?source=hash-mapping
size: 10828
timestamp: 1733208220327
+- pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz
+ name: editdistance
+ version: 0.8.1
+ sha256: d1cdf80a5d5014b0c9126a69a42ce55a457b457f6986ff69ca98e4fe4d2d8fed
+ requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl
name: einops
version: 0.8.2
@@ -5913,6 +5963,32 @@ packages:
- pkg:pypi/nh3?source=hash-mapping
size: 584955
timestamp: 1756737407424
+- pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl
+ name: nltk
+ version: 3.9.4
+ sha256: f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f
+ requires_dist:
+ - click
+ - joblib
+ - regex>=2021.8.3
+ - tqdm
+ - numpy ; extra == 'machine-learning'
+ - python-crfsuite ; extra == 'machine-learning'
+ - scikit-learn ; extra == 'machine-learning'
+ - scipy ; extra == 'machine-learning'
+ - matplotlib ; extra == 'plot'
+ - pyparsing ; extra == 'tgrep'
+ - twython ; extra == 'twitter'
+ - requests ; extra == 'corenlp'
+ - scipy ; extra == 'all'
+ - python-crfsuite ; extra == 'all'
+ - pyparsing ; extra == 'all'
+ - requests ; extra == 'all'
+ - numpy ; extra == 'all'
+ - scikit-learn ; extra == 'all'
+ - twython ; extra == 'all'
+ - matplotlib ; extra == 'all'
+ requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: numpy
version: 2.2.6
@@ -6100,6 +6176,26 @@ packages:
purls: []
size: 9327033
timestamp: 1751392489008
+- pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: 814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
+ name: orjson
+ version: 3.11.7
+ sha256: 1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace
+ requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
name: outdated
version: 0.2.2
@@ -7030,7 +7126,7 @@ packages:
- pypi: ./
name: pyhealth
version: 2.0.0
- sha256: f07719f9dceb759c35507216c8033d2f915d241418d4fad2ab51b37c0e73260f
+ sha256: a11efbf6fe99193820e0ee48e843cfb89526030b1d5e48ddc71c8fa543242572
requires_dist:
- torch~=2.7.1
- torchvision
@@ -7055,6 +7151,11 @@ packages:
- more-itertools~=10.8.0
- einops>=0.8.0
- linear-attention-transformer>=0.19.1
+ - orjson~=3.10
+ - torch-geometric>=2.6.0 ; extra == 'graph'
+ - editdistance~=0.8.1 ; extra == 'nlp'
+ - rouge-score~=0.1.2 ; extra == 'nlp'
+ - nltk~=3.9.1 ; extra == 'nlp'
requires_python: '>=3.12,<3.14'
- pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl
name: pyparsing
@@ -7416,6 +7517,16 @@ packages:
- pkg:pypi/rich?source=compressed-mapping
size: 201098
timestamp: 1753436991345
+- pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
+ name: rouge-score
+ version: 0.1.2
+ sha256: c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04
+ requires_dist:
+ - absl-py
+ - nltk
+ - numpy
+ - six>=1.14.0
+ requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl
name: s3transfer
version: 0.16.0
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 54e77670c..71b0d5b67 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -59,6 +59,8 @@ def __init__(self, *args, **kwargs):
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
+from .fhir_cehr import ConceptVocab
+from .mimic4_fhir import MIMIC4FHIRDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
diff --git a/pyhealth/datasets/configs/mimic4_fhir.yaml b/pyhealth/datasets/configs/mimic4_fhir.yaml
new file mode 100644
index 000000000..989ab79cf
--- /dev/null
+++ b/pyhealth/datasets/configs/mimic4_fhir.yaml
@@ -0,0 +1,118 @@
+# MIMIC-IV FHIR Resource Flattening Configuration
+# ================================================
+#
+# This YAML defines the normalized schema for MIMIC-IV FHIR exports after streaming
+# ingestion. Raw NDJSON/NDJSON.GZ resources are parsed and flattened into six
+# per-resource-type Parquet tables (see ``version`` below), then loaded through the
+# standard BaseDataset pipeline for task construction.
+#
+# For ingest details, see the docstring of ``stream_fhir_ndjson_to_flat_tables()``
+# in ``pyhealth/datasets/mimic4_fhir.py``.
+
+version: "fhir_r4_flattened"
+
+# Glob Pattern(s) for NDJSON File Discovery
+# ===========================================
+#
+# ``glob_patterns`` (list) or ``glob_pattern`` (string): Patterns to match NDJSON files
+# under the ingest root directory. Patterns are applied via pathlib.Path.glob().
+#
+# Default: Six targeted patterns matching PhysioNet MIMIC-IV FHIR Mimic* shard families
+# that map to flattened tables. This avoids decompressing and parsing ~10% of PhysioNet
+# exports (MedicationAdministration, Specimen, Organization, …) that are skipped by
+# the flattener.
+#
+# Alternatives:
+# - For non-PhysioNet naming, use a single broad pattern:
+# glob_pattern: "**/*.ndjson.gz"
+# - To test on a subset, use a narrower list:
+# glob_patterns:
+# - "**/MimicPatient*.ndjson.gz"
+# - "**/MimicObservation*.ndjson.gz"
+#
+# Notes:
+# - Patterns use ``**/`` for recursive search (works in both flat and nested layouts).
+# - Can be overridden at runtime via MIMIC4FHIRDataset(glob_pattern=...) or
+# MIMIC4FHIRDataset(glob_patterns=[...]).
+
+glob_patterns:
+ - "**/MimicPatient*.ndjson.gz"
+ - "**/MimicEncounter*.ndjson.gz"
+ - "**/MimicCondition*.ndjson.gz"
+ - "**/MimicObservation*.ndjson.gz"
+ - "**/MimicMedicationRequest*.ndjson.gz"
+ - "**/MimicProcedure*.ndjson.gz"
+
+# Flattened Table Schema
+# ======================
+#
+# Each table is normalized from a single FHIR resource type. Columns are:
+# - patient_id (str): Foreign key to patient (derived from subject.reference or id).
+# - [timestamp] (str): ISO 8601 datetime string (coerced; nullable).
+# - attributes (List[str]): Additional columns from the resource.
+#
+# Unsupported resource types (Medication, MedicationAdministration, Specimen, …)
+# are silently dropped during ingest; only tables listed here are written.
+
+tables:
+ patient:
+ file_path: "patient.parquet"
+ patient_id: "patient_id"
+ timestamp: "birth_date"
+ attributes:
+ - "patient_fhir_id"
+ - "birth_date"
+ - "gender"
+ - "deceased_boolean"
+ - "deceased_datetime"
+
+ encounter:
+ file_path: "encounter.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "encounter_class"
+ - "encounter_end"
+
+ condition:
+ file_path: "condition.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ observation:
+ file_path: "observation.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ medication_request:
+ file_path: "medication_request.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ procedure:
+ file_path: "procedure.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
diff --git a/pyhealth/datasets/fhir_cehr.py b/pyhealth/datasets/fhir_cehr.py
new file mode 100644
index 000000000..67ae35c5e
--- /dev/null
+++ b/pyhealth/datasets/fhir_cehr.py
@@ -0,0 +1,364 @@
+"""CEHR-style tokenization, vocabulary, and sequence building for FHIR timelines.
+
+Key public API
+--------------
+ConceptVocab
+ Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serializable.
+
+build_cehr_sequences(patient, vocab, max_len)
+ Flatten a Patient's tabular FHIR rows into CEHR-aligned feature lists.
+
+infer_mortality_label(patient)
+ Heuristic binary mortality label from flattened patient rows.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import orjson
+
+from pyhealth.data import Patient
+
+from .fhir_ingest import as_naive, parse_dt
+
+DEFAULT_PAD = 0
+DEFAULT_UNK = 1
+
+__all__ = [
+ # Constants
+ "DEFAULT_PAD",
+ "DEFAULT_UNK",
+ "EVENT_TYPE_TO_TOKEN_TYPE",
+ # Vocabulary
+ "ConceptVocab",
+ "ensure_special_tokens",
+ # Sequence building
+ "collect_cehr_timeline_events",
+ "warm_mpf_vocab_from_patient",
+ "build_cehr_sequences",
+ # Labels
+ "infer_mortality_label",
+]
+
+EVENT_TYPE_TO_TOKEN_TYPE = {
+ "encounter": 1,
+ "condition": 2,
+ "medication_request": 3,
+ "observation": 4,
+ "procedure": 5,
+}
+
+# Table-driven lookups for flattened event-row column access.
+_CONCEPT_KEY_COL: Dict[str, str] = {
+ "condition": "condition/concept_key",
+ "observation": "observation/concept_key",
+ "medication_request": "medication_request/concept_key",
+ "procedure": "procedure/concept_key",
+}
+
+_ENCOUNTER_ID_COL: Dict[str, str] = {
+ "condition": "condition/encounter_id",
+ "observation": "observation/encounter_id",
+ "medication_request": "medication_request/encounter_id",
+ "procedure": "procedure/encounter_id",
+ "encounter": "encounter/encounter_id",
+}
+
+# ---------------------------------------------------------------------------
+# ConceptVocab
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ConceptVocab:
+ """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1."""
+
+ token_to_id: Dict[str, int] = field(default_factory=dict)
+ pad_id: int = DEFAULT_PAD
+ unk_id: int = DEFAULT_UNK
+ _next_id: int = 2
+
+ def __post_init__(self) -> None:
+ if not self.token_to_id:
+ self.token_to_id = {"": self.pad_id, "": self.unk_id}
+ self._next_id = 2
+
+ def add_token(self, key: str) -> int:
+ if key in self.token_to_id:
+ return self.token_to_id[key]
+ tid = self._next_id
+ self.token_to_id[key] = tid
+ self._next_id += 1
+ return tid
+
+ def __getitem__(self, key: str) -> int:
+ return self.token_to_id.get(key, self.unk_id)
+
+ @property
+ def vocab_size(self) -> int:
+ return self._next_id
+
+ def to_json(self) -> Dict[str, Any]:
+ return {"token_to_id": self.token_to_id, "next_id": self._next_id}
+
+ @classmethod
+ def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab":
+ vocab = cls()
+ loaded = dict(data.get("token_to_id") or {})
+ if not loaded:
+ vocab._next_id = int(data.get("next_id", 2))
+ return vocab
+ vocab.token_to_id = loaded
+ vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1))
+ return vocab
+
+ def save(self, path: str) -> None:
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
+ Path(path).write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS))
+
+ @classmethod
+ def load(cls, path: str) -> "ConceptVocab":
+ return cls.from_json(orjson.loads(Path(path).read_bytes()))
+
+
+def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]:
+ """Add EHRMamba/CEHR special tokens and return their ids."""
+ return {name: vocab.add_token(name) for name in ("", "", "", "")}
+
+
+# ---------------------------------------------------------------------------
+# Row utilities for flattened event stream
+# ---------------------------------------------------------------------------
+
+
+def _clean_string(value: Any) -> Optional[str]:
+ if value is None:
+ return None
+ if isinstance(value, str):
+ return value.strip() or None
+ return str(value)
+
+
+def _deceased_boolean_column_means_dead(value: Any) -> bool:
+ """True only for an explicit affirmative stored flag (not Python truthiness)."""
+ s = _clean_string(value)
+ return s is not None and s.lower() == "true"
+
+
+def _row_datetime(value: Any) -> Optional[datetime]:
+ if value is None:
+ return None
+ if isinstance(value, datetime):
+ return as_naive(value)
+ try:
+ return parse_dt(str(value))
+ except Exception:
+ return None
+
+
+def _concept_key_from_row(row: Dict[str, Any]) -> str:
+ event_type = row.get("event_type")
+ col = _CONCEPT_KEY_COL.get(event_type)
+ if col:
+ return _clean_string(row.get(col)) or f"{event_type}|unknown"
+ if event_type == "encounter":
+ enc_class = _clean_string(row.get("encounter/encounter_class"))
+ return f"encounter|{enc_class}" if enc_class else "encounter|unknown"
+ return f"{event_type or 'event'}|unknown"
+
+
+def _linked_encounter_id_from_row(row: Dict[str, Any]) -> Optional[str]:
+ col = _ENCOUNTER_ID_COL.get(row.get("event_type"))
+ return _clean_string(row.get(col)) if col else None
+
+
+def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]:
+ for row in patient.data_source.iter_rows(named=True):
+ if row.get("event_type") != "patient":
+ continue
+ birth = _row_datetime(row.get("timestamp"))
+ if birth is not None:
+ return birth
+ raw = _clean_string(row.get("patient/birth_date"))
+ if raw:
+ return parse_dt(raw)
+ return None
+
+
+def _sequential_visit_idx_for_time(
+ event_time: Optional[datetime],
+ visit_encounters: List[Tuple[datetime, int]],
+) -> int:
+ if not visit_encounters:
+ return 0
+ if event_time is None:
+ return visit_encounters[-1][1]
+ event_time = as_naive(event_time)
+ chosen = visit_encounters[0][1]
+ for encounter_start, visit_idx in visit_encounters:
+ if encounter_start <= event_time:
+ chosen = visit_idx
+ else:
+ break
+ return chosen
+
+
+# ---------------------------------------------------------------------------
+# CEHR timeline and sequence building
+# ---------------------------------------------------------------------------
+
+
+def collect_cehr_timeline_events(
+ patient: Patient,
+) -> List[Tuple[datetime, str, str, int]]:
+ """Collect (time, concept_key, event_type, visit_idx) tuples from a patient's rows."""
+ rows = list(
+ patient.data_source.sort(["timestamp", "event_type"], nulls_last=True).iter_rows(named=True)
+ )
+
+ encounter_rows: List[Tuple[datetime, str]] = []
+ for row in rows:
+ if row.get("event_type") != "encounter":
+ continue
+ enc_id = _linked_encounter_id_from_row(row)
+ enc_start = _row_datetime(row.get("timestamp"))
+ if enc_id is not None and enc_start is not None:
+ encounter_rows.append((enc_start, enc_id))
+
+ encounter_rows.sort(key=lambda pair: pair[0])
+ encounter_visit_idx = {enc_id: idx for idx, (_, enc_id) in enumerate(encounter_rows)}
+ encounter_start_by_id = {enc_id: enc_start for enc_start, enc_id in encounter_rows}
+ visit_encounters = [(enc_start, idx) for idx, (enc_start, _) in enumerate(encounter_rows)]
+
+ events: List[Tuple[datetime, str, str, int]] = []
+ unlinked: List[Tuple[Optional[datetime], str, str]] = []
+
+ for row in rows:
+ event_type = row.get("event_type")
+ if event_type not in EVENT_TYPE_TO_TOKEN_TYPE:
+ continue
+ event_time = _row_datetime(row.get("timestamp"))
+ concept_key = _concept_key_from_row(row)
+
+ if event_type == "encounter":
+ enc_id = _linked_encounter_id_from_row(row)
+ if enc_id is None or event_time is None:
+ continue
+ visit_idx = encounter_visit_idx.get(enc_id)
+ if visit_idx is None:
+ continue
+ events.append((event_time, concept_key, event_type, visit_idx))
+ continue
+
+ enc_id = _linked_encounter_id_from_row(row)
+ if enc_id and enc_id in encounter_visit_idx:
+ visit_idx = encounter_visit_idx[enc_id]
+ if event_time is None:
+ event_time = encounter_start_by_id.get(enc_id)
+ if event_time is None:
+ continue
+ events.append((event_time, concept_key, event_type, visit_idx))
+ else:
+ unlinked.append((event_time, concept_key, event_type))
+
+ for event_time, concept_key, event_type in unlinked:
+ visit_idx = _sequential_visit_idx_for_time(event_time, visit_encounters)
+ if event_time is None:
+ if not visit_encounters:
+ continue
+ for enc_start, enc_visit_idx in visit_encounters:
+ if enc_visit_idx == visit_idx:
+ event_time = enc_start
+ break
+ else:
+ event_time = visit_encounters[-1][0]
+ if event_time is None:
+ continue
+ events.append((event_time, concept_key, event_type, visit_idx))
+
+ events.sort(key=lambda item: item[0])
+ return events
+
+
+def warm_mpf_vocab_from_patient(
+ vocab: ConceptVocab,
+ patient: Patient,
+ clinical_cap: int,
+) -> None:
+ """Add concept keys from the last clinical_cap events of a patient to vocab."""
+ tail = collect_cehr_timeline_events(patient)[-clinical_cap:] if clinical_cap > 0 else []
+ for _, concept_key, _, _ in tail:
+ vocab.add_token(concept_key)
+
+
+def build_cehr_sequences(
+ patient: Patient,
+ vocab: ConceptVocab,
+ max_len: int,
+ *,
+ base_time: Optional[datetime] = None,
+ grow_vocab: bool = True,
+) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]:
+ """Flatten a patient's tabular FHIR rows into CEHR-aligned feature lists."""
+ events = collect_cehr_timeline_events(patient)
+ birth = _birth_datetime_from_patient(patient)
+
+ if base_time is None:
+ base_time = events[0][0] if events else datetime.now()
+ base_time = as_naive(base_time)
+ birth = as_naive(birth)
+
+ concept_ids: List[int] = []
+ token_types: List[int] = []
+ time_stamps: List[float] = []
+ ages: List[float] = []
+ visit_orders: List[int] = []
+ visit_segments: List[int] = []
+
+ for event_time, concept_key, event_type, visit_idx in (events[-max_len:] if max_len > 0 else []):
+ event_time = as_naive(event_time)
+ concept_id = vocab.add_token(concept_key) if grow_vocab else vocab[concept_key]
+ token_type = EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0)
+ time_delta = (
+ float((event_time - base_time).total_seconds())
+ if base_time is not None and event_time is not None
+ else 0.0
+ )
+ age_years = (
+ (event_time - birth).days / 365.25
+ if birth is not None and event_time is not None
+ else 0.0
+ )
+ concept_ids.append(concept_id)
+ token_types.append(token_type)
+ time_stamps.append(time_delta)
+ ages.append(age_years)
+ visit_orders.append(min(visit_idx, 511))
+ visit_segments.append(visit_idx % 2)
+
+ return concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments
+
+
+# ---------------------------------------------------------------------------
+# Label inference
+# ---------------------------------------------------------------------------
+
+
+def infer_mortality_label(patient: Patient) -> int:
+ """Heuristic binary mortality label from flattened patient rows."""
+ for row in patient.data_source.iter_rows(named=True):
+ if row.get("event_type") == "patient":
+ if _deceased_boolean_column_means_dead(row.get("patient/deceased_boolean")):
+ return 1
+ if _clean_string(row.get("patient/deceased_datetime")):
+ return 1
+ for row in patient.data_source.iter_rows(named=True):
+ if row.get("event_type") == "condition":
+ key = (_clean_string(row.get("condition/concept_key")) or "").lower()
+ if any(token in key for token in ("death", "deceased", "mortality")):
+ return 1
+ return 0
diff --git a/pyhealth/datasets/fhir_ingest.py b/pyhealth/datasets/fhir_ingest.py
new file mode 100644
index 000000000..e4926e58d
--- /dev/null
+++ b/pyhealth/datasets/fhir_ingest.py
@@ -0,0 +1,500 @@
+"""FHIR NDJSON parsing, flattening, and Parquet table writing.
+
+Key public API
+--------------
+stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir)
+ Stream all matching NDJSON/NDJSON.GZ resources into six per-type Parquet tables.
+
+sorted_ndjson_files(root, glob_pattern)
+ List matching NDJSON files under root (deduplicated, sorted).
+
+filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids)
+ Subset existing flattened tables to a specific patient cohort.
+
+synthetic_mpf_two_patient_ndjson_text()
+ In-memory two-patient fixture for tests and quick-test mode.
+"""
+
+from __future__ import annotations
+
+import gzip
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
+
+import orjson
+import polars as pl
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+GlobPatternArg = str | Sequence[str]
+"""Single glob string or sequence of strings for NDJSON file discovery."""
+
+__all__ = [
+ # Types
+ "GlobPatternArg",
+ # Constants
+ "FHIR_SCHEMA_VERSION",
+ "FHIR_TABLES",
+ "FHIR_TABLES_FOR_PATIENT_IDS",
+ "FHIR_TABLE_FILE_NAMES",
+ "FHIR_TABLE_COLUMNS",
+ # Datetime helpers (also used by fhir_cehr)
+ "parse_dt",
+ "as_naive",
+ # FHIR iteration
+ "iter_ndjson_objects",
+ "iter_resources_from_ndjson_obj",
+ # Resource extraction
+ "patient_id_for_resource",
+ # Pipeline
+ "sorted_ndjson_files",
+ "stream_fhir_ndjson_to_flat_tables",
+ "filter_flat_tables_by_patient_ids",
+ # Synthetic fixtures
+ "synthetic_mpf_one_patient_resources",
+ "synthetic_mpf_two_patient_resources",
+ "synthetic_mpf_one_patient_ndjson_text",
+ "synthetic_mpf_two_patient_ndjson_text",
+]
+
+FHIR_SCHEMA_VERSION = 3
+
+FHIR_TABLES: List[str] = [
+ "patient",
+ "encounter",
+ "condition",
+ "observation",
+ "medication_request",
+ "procedure",
+]
+
+FHIR_TABLES_FOR_PATIENT_IDS: List[str] = [t for t in FHIR_TABLES if t != "patient"]
+
+FHIR_TABLE_FILE_NAMES: Dict[str, str] = {t: f"{t}.parquet" for t in FHIR_TABLES}
+
+FHIR_TABLE_COLUMNS: Dict[str, List[str]] = {
+ "patient": ["patient_id", "patient_fhir_id", "birth_date", "gender", "deceased_boolean", "deceased_datetime"],
+ "encounter": ["patient_id", "resource_id", "encounter_id", "event_time", "encounter_class", "encounter_end"],
+ "condition": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "observation": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "medication_request": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "procedure": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+}
+
+# ---------------------------------------------------------------------------
+# Datetime helpers (also imported by fhir_cehr)
+# ---------------------------------------------------------------------------
+
+
+def parse_dt(s: Optional[str]) -> Optional[datetime]:
+ if not s:
+ return None
+ try:
+ dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
+ except ValueError:
+ dt = None
+ if dt is None and len(s) >= 10:
+ try:
+ dt = datetime.strptime(s[:10], "%Y-%m-%d")
+ except ValueError:
+ return None
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+def as_naive(dt: Optional[datetime]) -> Optional[datetime]:
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+# ---------------------------------------------------------------------------
+# FHIR JSON helpers
+# ---------------------------------------------------------------------------
+
+
+def _coding_key(coding: Dict[str, Any]) -> str:
+ return f"{coding.get('system') or 'unknown'}|{coding.get('code') or 'unknown'}"
+
+
+def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]:
+ if not obj:
+ return None
+ codings = obj.get("coding") or []
+ if not codings and "concept" in obj:
+ codings = (obj.get("concept") or {}).get("coding") or []
+ return _coding_key(codings[0]) if codings else None
+
+
+def _ref_id(ref: Optional[str]) -> Optional[str]:
+ if not ref:
+ return None
+ return ref.rsplit("/", 1)[-1] if "/" in ref else ref
+
+
+def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]:
+ if not isinstance(raw, dict):
+ return None
+ resource = raw.get("resource") if "resource" in raw else raw
+ return resource if isinstance(resource, dict) else None
+
+
+def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+ """Yield resource dicts from one parsed NDJSON object (Bundle or bare resource)."""
+ if "entry" in obj:
+ for entry in obj.get("entry") or []:
+ resource = entry.get("resource")
+ if isinstance(resource, dict):
+ yield resource
+ return
+ resource = _unwrap_resource_dict(obj)
+ if resource is not None:
+ yield resource
+
+
+def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]:
+ """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file."""
+ opener = (
+ gzip.open(path, "rt", encoding="utf-8", errors="replace")
+ if path.suffix == ".gz"
+ else open(path, encoding="utf-8", errors="replace")
+ )
+ with opener as stream:
+ for line in stream:
+ line = line.strip()
+ if not line:
+ continue
+ parsed = orjson.loads(line)
+ if isinstance(parsed, dict):
+ yield parsed
+
+
+# ---------------------------------------------------------------------------
+# Resource field extraction
+# ---------------------------------------------------------------------------
+
+
+def _clinical_concept_key(res: Dict[str, Any]) -> Optional[str]:
+ """Resolve a stable token key from a FHIR resource."""
+ resource_type = res.get("resourceType")
+ if resource_type == "MedicationRequest":
+ med_cc = res.get("medicationCodeableConcept")
+ if isinstance(med_cc, dict):
+ key = _first_coding(med_cc)
+ if key:
+ return key
+ med_ref = res.get("medicationReference")
+ if isinstance(med_ref, dict):
+ ref = med_ref.get("reference")
+ if ref:
+ return f"MedicationRequest/reference|{_ref_id(ref) or ref}"
+ return None
+ code = res.get("code")
+ return _first_coding(code) if isinstance(code, dict) else None
+
+
+def patient_id_for_resource(
+ resource: Dict[str, Any],
+ resource_type: Optional[str] = None,
+) -> Optional[str]:
+ resource_type = resource_type or resource.get("resourceType")
+ if resource_type == "Patient":
+ pid = resource.get("id")
+ return str(pid) if pid is not None else None
+ if resource_type in {"Encounter", "Condition", "Observation", "MedicationRequest", "Procedure"}:
+ return _ref_id((resource.get("subject") or {}).get("reference"))
+ return None
+
+
+def _resource_time_string(
+ resource: Dict[str, Any],
+ resource_type: Optional[str] = None,
+) -> Optional[str]:
+ resource_type = resource_type or resource.get("resourceType")
+ if resource_type == "Patient":
+ return resource.get("birthDate")
+ if resource_type == "Encounter":
+ return (resource.get("period") or {}).get("start")
+ if resource_type == "Condition":
+ return resource.get("onsetDateTime") or resource.get("recordedDate")
+ if resource_type == "Observation":
+ return resource.get("effectiveDateTime") or resource.get("issued")
+ if resource_type == "MedicationRequest":
+ return resource.get("authoredOn")
+ if resource_type == "Procedure":
+ return resource.get("performedDateTime") or resource.get("recordedDate")
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Flattening
+# ---------------------------------------------------------------------------
+
+
+def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]:
+ """Map Patient.deceasedBoolean to stored "true"/"false"/None.
+
+ FHIR JSON uses real booleans; some exports use strings. Python's
+ bool("false") is True, so we must not coerce with bool().
+ """
+ if value is None:
+ return None
+ if value is True:
+ return "true"
+ if value is False:
+ return "false"
+ if isinstance(value, str):
+ key = value.strip().lower()
+ if key in ("true", "1", "yes", "y", "t"):
+ return "true"
+ if key in ("false", "0", "no", "n", "f", ""):
+ return "false"
+ return None
+ if isinstance(value, (int, float)) and not isinstance(value, bool):
+ if value == 0:
+ return "false"
+ if value == 1:
+ return "true"
+ return None
+ return None
+
+
+_RESOURCE_TYPE_TO_TABLE: Dict[str, str] = {
+ "Condition": "condition",
+ "Observation": "observation",
+ "MedicationRequest": "medication_request",
+ "Procedure": "procedure",
+}
+
+
+def _flatten_resource_to_table_row(
+ resource: Dict[str, Any],
+) -> Optional[Tuple[str, Dict[str, Optional[str]]]]:
+ """Map one FHIR resource dict to (table_name, row_dict), or None if unsupported."""
+ resource_type = resource.get("resourceType")
+ patient_id = patient_id_for_resource(resource, resource_type)
+ if not patient_id:
+ return None
+
+ if resource_type == "Patient":
+ return "patient", {
+ "patient_id": patient_id,
+ "patient_fhir_id": str(resource.get("id") or patient_id),
+ "birth_date": resource.get("birthDate"),
+ "gender": resource.get("gender"),
+ "deceased_boolean": _normalize_deceased_boolean_for_storage(resource.get("deceasedBoolean")),
+ "deceased_datetime": resource.get("deceasedDateTime"),
+ }
+
+ resource_id = str(resource.get("id")) if resource.get("id") is not None else None
+ event_time = _resource_time_string(resource, resource_type)
+
+ if resource_type == "Encounter":
+ return "encounter", {
+ "patient_id": patient_id,
+ "resource_id": resource_id,
+ "encounter_id": resource_id,
+ "event_time": event_time,
+ "encounter_class": (resource.get("class") or {}).get("code"),
+ "encounter_end": (resource.get("period") or {}).get("end"),
+ }
+
+ table_name = _RESOURCE_TYPE_TO_TABLE.get(resource_type)
+ if table_name is None:
+ return None
+ return table_name, {
+ "patient_id": patient_id,
+ "resource_id": resource_id,
+ "encounter_id": _ref_id((resource.get("encounter") or {}).get("reference")),
+ "event_time": event_time,
+ "concept_key": _clinical_concept_key(resource),
+ }
+
+
+# ---------------------------------------------------------------------------
+# Parquet writer
+# ---------------------------------------------------------------------------
+
+
+def _table_schema(table_name: str) -> pa.Schema:
+ return pa.schema([(col, pa.string()) for col in FHIR_TABLE_COLUMNS[table_name]])
+
+
+class _BufferedParquetWriter:
+ def __init__(self, path: Path, schema: pa.Schema, batch_size: int = 50_000) -> None:
+ self.path = path
+ self.schema = schema
+ self.batch_size = batch_size
+ self.rows: List[Dict[str, Any]] = []
+ self.writer: Optional[pq.ParquetWriter] = None
+ self.path.parent.mkdir(parents=True, exist_ok=True)
+
+ def add(self, row: Dict[str, Any]) -> None:
+ self.rows.append(row)
+ if len(self.rows) >= self.batch_size:
+ self.flush()
+
+ def flush(self) -> None:
+ if not self.rows:
+ return
+ table = pa.Table.from_pylist(self.rows, schema=self.schema)
+ if self.writer is None:
+ self.writer = pq.ParquetWriter(str(self.path), self.schema)
+ self.writer.write_table(table)
+ self.rows.clear()
+
+ def close(self) -> None:
+ self.flush()
+ if self.writer is None:
+ pq.write_table(pa.Table.from_pylist([], schema=self.schema), str(self.path))
+ return
+ self.writer.close()
+
+
+# ---------------------------------------------------------------------------
+# Pipeline
+# ---------------------------------------------------------------------------
+
+
+def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]:
+ """Return sorted unique file paths under root matching glob pattern(s).
+
+ Args:
+ root: Root directory to search under.
+ glob_pattern: Single glob string or sequence of glob strings.
+
+ Returns:
+ Sorted list of matching files. Empty if no matches.
+ """
+ patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern)
+ files: set[Path] = set()
+ for pat in patterns:
+ files.update(p for p in root.glob(pat) if p.is_file())
+ return sorted(files, key=lambda p: str(p))
+
+
+def stream_fhir_ndjson_to_flat_tables(
+ root: Path,
+ glob_pattern: GlobPatternArg,
+ out_dir: Path,
+) -> None:
+ """Stream NDJSON resources into normalized per-resource Parquet tables under out_dir.
+
+ Args:
+ root: Root directory containing NDJSON/NDJSON.GZ files.
+ glob_pattern: Single glob string or sequence of glob strings.
+ out_dir: Output directory for per-resource-type Parquet tables.
+ Creates patient.parquet, encounter.parquet, condition.parquet,
+ observation.parquet, medication_request.parquet, procedure.parquet.
+ """
+ out_dir.mkdir(parents=True, exist_ok=True)
+ writers = {
+ name: _BufferedParquetWriter(path=out_dir / FHIR_TABLE_FILE_NAMES[name], schema=_table_schema(name))
+ for name in FHIR_TABLES
+ }
+ try:
+ for file_path in sorted_ndjson_files(root, glob_pattern):
+ for ndjson_obj in iter_ndjson_objects(file_path):
+ for resource in iter_resources_from_ndjson_obj(ndjson_obj):
+ result = _flatten_resource_to_table_row(resource)
+ if result is not None:
+ writers[result[0]].add(result[1])
+ finally:
+ for writer in writers.values():
+ writer.close()
+
+
+def _sorted_patient_ids_from_flat_tables(table_dir: Path) -> List[str]:
+ patient_table = table_dir / FHIR_TABLE_FILE_NAMES["patient"]
+ if patient_table.exists():
+ return (
+ pl.scan_parquet(str(patient_table))
+ .select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+ frames = [
+ pl.scan_parquet(str(table_dir / FHIR_TABLE_FILE_NAMES[t])).select("patient_id")
+ for t in FHIR_TABLES_FOR_PATIENT_IDS
+ ]
+ return (
+ pl.concat(frames)
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+
+
+def filter_flat_tables_by_patient_ids(
+ source_dir: Path,
+ out_dir: Path,
+ keep_ids: Sequence[str],
+) -> None:
+ """Filter all flattened tables to only include rows for the given patient IDs."""
+ out_dir.mkdir(parents=True, exist_ok=True)
+ keep_set = set(keep_ids)
+ for name in FHIR_TABLES:
+ src = source_dir / FHIR_TABLE_FILE_NAMES[name]
+ dst = out_dir / FHIR_TABLE_FILE_NAMES[name]
+ pl.scan_parquet(str(src)).filter(pl.col("patient_id").is_in(keep_set)).sink_parquet(str(dst))
+
+
+# ---------------------------------------------------------------------------
+# Synthetic fixtures (for tests and --quick-test mode)
+# ---------------------------------------------------------------------------
+
+
+def synthetic_mpf_one_patient_resources() -> List[Dict[str, Any]]:
+ return [
+ {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"},
+ {
+ "resourceType": "Encounter",
+ "id": "e1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "period": {"start": "2020-06-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Condition",
+ "id": "c1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "encounter": {"reference": "Encounter/e1"},
+ "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]},
+ "onsetDateTime": "2020-06-01T11:00:00Z",
+ },
+ ]
+
+
+def synthetic_mpf_two_patient_resources() -> List[Dict[str, Any]]:
+ return [
+ *synthetic_mpf_one_patient_resources(),
+ {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True},
+ {
+ "resourceType": "Encounter",
+ "id": "e-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "period": {"start": "2020-07-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation",
+ "id": "o-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "encounter": {"reference": "Encounter/e-dead"},
+ "effectiveDateTime": "2020-07-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]},
+ },
+ ]
+
+
+def synthetic_mpf_one_patient_ndjson_text() -> str:
+ return "\n".join(orjson.dumps(r).decode("utf-8") for r in synthetic_mpf_one_patient_resources()) + "\n"
+
+
+def synthetic_mpf_two_patient_ndjson_text() -> str:
+ return "\n".join(orjson.dumps(r).decode("utf-8") for r in synthetic_mpf_two_patient_resources()) + "\n"
diff --git a/pyhealth/datasets/mimic4_fhir.py b/pyhealth/datasets/mimic4_fhir.py
new file mode 100644
index 000000000..d9f8e428a
--- /dev/null
+++ b/pyhealth/datasets/mimic4_fhir.py
@@ -0,0 +1,369 @@
+"""MIMIC-IV FHIR ingestion using flattened resource tables.
+
+Architecture
+------------
+1. Stream NDJSON/NDJSON.GZ FHIR resources from disk.
+2. Normalize each resource type into a 2D table (Patient, Encounter, Condition,
+ Observation, MedicationRequest, Procedure) via :mod:`~pyhealth.datasets.fhir_ingest`.
+3. Feed those tables through the standard YAML-driven
+ :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task processing
+ operates on :class:`~pyhealth.data.Patient` and ``global_event_df`` rows.
+
+Module layout
+-------------
+- :mod:`~pyhealth.datasets.fhir_ingest` -- NDJSON → Parquet pipeline.
+- :mod:`~pyhealth.datasets.fhir_cehr` -- CEHR tokenization, ConceptVocab, labels.
+- This module -- :class:`MIMIC4FHIRDataset` wiring ingest
+ into :class:`~pyhealth.datasets.BaseDataset`.
+"""
+
+from __future__ import annotations
+
+import functools
+import hashlib
+import itertools
+import logging
+import operator
+import os
+import shutil
+import uuid
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence
+
+import dask.dataframe as dd
+import narwhals as nw
+import orjson
+import pandas as pd
+import platformdirs
+import polars as pl
+from litdata.processing.data_processor import in_notebook
+from yaml import safe_load
+
+from ..data import Patient
+from .base_dataset import BaseDataset
+from .fhir_cehr import ConceptVocab, ensure_special_tokens, warm_mpf_vocab_from_patient
+from .fhir_ingest import (
+ FHIR_SCHEMA_VERSION,
+ FHIR_TABLE_FILE_NAMES,
+ FHIR_TABLES,
+ _sorted_patient_ids_from_flat_tables,
+ filter_flat_tables_by_patient_ids,
+ stream_fhir_ndjson_to_flat_tables,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def read_fhir_settings_yaml(path: Optional[str] = None) -> Dict[str, Any]:
+ if path is None:
+ path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_fhir.yaml")
+ with open(path, encoding="utf-8") as stream:
+ data = safe_load(stream)
+ return data if isinstance(data, dict) else {}
+
+
+def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series:
+ if getattr(part.dtype, "tz", None) is not None:
+ part = part.dt.tz_localize(None)
+ return part.astype("datetime64[ms]")
+
+
+class MIMIC4FHIRDataset(BaseDataset):
+ """MIMIC-IV FHIR with flattened resource tables and standard PyHealth task flow.
+
+ Streams raw MIMIC-IV FHIR NDJSON/NDJSON.GZ exports into six flattened Parquet
+ tables then pipelines them through :class:`~pyhealth.datasets.BaseDataset` for
+ standard downstream task processing (global event dataframe, patient iteration,
+ task sampling).
+
+ **Ingest flow (out-of-core):**
+
+ 1. Scan NDJSON files matching ``glob_patterns`` (defaults to six Mimic* families).
+ 2. Parse and flatten each FHIR resource into a row in the appropriate table.
+ 3. Cache normalized tables under ``cache_dir / {uuid} / flattened_tables/``.
+ 4. Load tables into ``global_event_df`` via YAML config.
+
+ **Cache fingerprinting:** includes ``glob_patterns`` and YAML digest, so changes
+ to either create a new independent cache.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ config_path: Optional[str] = None,
+ glob_pattern: Optional[str] = None,
+ glob_patterns: Optional[Sequence[str]] = None,
+ max_patients: Optional[int] = None,
+ ingest_num_shards: Optional[int] = None,
+ vocab_path: Optional[str] = None,
+ cache_dir: Optional[str | Path] = None,
+ num_workers: int = 1,
+ dev: bool = False,
+ ) -> None:
+ """Initialize a MIMIC-IV FHIR dataset.
+
+ Args:
+ root: Path to the NDJSON/NDJSON.GZ export directory.
+ config_path: Path to a custom YAML config. Defaults to
+ ``pyhealth/datasets/configs/mimic4_fhir.yaml``.
+ glob_pattern: Single glob for NDJSON files. Mutually exclusive
+ with ``glob_patterns``.
+ glob_patterns: Multiple glob patterns. Mutually exclusive with
+ ``glob_pattern``.
+ max_patients: Limit ingest to the first N unique patient IDs.
+ ingest_num_shards: Ignored; retained for API compatibility.
+ vocab_path: Path to a pre-built ConceptVocab JSON file.
+ cache_dir: Cache directory root (UUID subdir appended per config).
+ num_workers: Worker processes for task sampling.
+ dev: Development mode; limits to 1000 patients if max_patients is None.
+ """
+ del ingest_num_shards
+
+ default_cfg = os.path.join(os.path.dirname(__file__), "configs", "mimic4_fhir.yaml")
+ self._fhir_config_path = str(Path(config_path or default_cfg).resolve())
+ self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path)
+
+ if glob_pattern is not None and glob_patterns is not None:
+ raise ValueError("Pass at most one of glob_pattern and glob_patterns.")
+ if glob_patterns is not None:
+ self.glob_patterns: List[str] = list(glob_patterns)
+ elif glob_pattern is not None:
+ self.glob_patterns = [glob_pattern]
+ else:
+ raw_list = self._fhir_settings.get("glob_patterns")
+ if raw_list:
+ if not isinstance(raw_list, list):
+ raise TypeError("mimic4_fhir.yaml glob_patterns must be a list of strings.")
+ self.glob_patterns = [str(x) for x in raw_list]
+ elif self._fhir_settings.get("glob_pattern") is not None:
+ self.glob_patterns = [str(self._fhir_settings["glob_pattern"])]
+ else:
+ self.glob_patterns = ["**/*.ndjson.gz"]
+
+ self.glob_pattern = (
+ self.glob_patterns[0] if len(self.glob_patterns) == 1 else "; ".join(self.glob_patterns)
+ )
+ self.max_patients = 1000 if dev and max_patients is None else max_patients
+ self.source_root = str(Path(root).expanduser().resolve())
+ self.vocab = (
+ ConceptVocab.load(vocab_path)
+ if vocab_path and os.path.isfile(vocab_path)
+ else ConceptVocab()
+ )
+ super().__init__(
+ root=self.source_root,
+ tables=FHIR_TABLES,
+ dataset_name="mimic4_fhir",
+ config_path=self._fhir_config_path,
+ cache_dir=cache_dir,
+ num_workers=num_workers,
+ dev=dev,
+ )
+
+ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
+ try:
+ yaml_digest = hashlib.sha256(Path(self._fhir_config_path).read_bytes()).hexdigest()[:16]
+ except OSError:
+ yaml_digest = "missing"
+ identity = orjson.dumps(
+ {
+ "source_root": self.source_root,
+ "tables": sorted(self.tables),
+ "dataset_name": self.dataset_name,
+ "dev": self.dev,
+ "glob_patterns": self.glob_patterns,
+ "max_patients": self.max_patients,
+ "fhir_schema_version": FHIR_SCHEMA_VERSION,
+ "fhir_yaml_digest16": yaml_digest,
+ },
+ option=orjson.OPT_SORT_KEYS,
+ ).decode("utf-8")
+ cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity))
+ out = (
+ Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id
+ if cache_dir is None
+ else Path(cache_dir) / cache_id
+ )
+ out.mkdir(parents=True, exist_ok=True)
+ logger.info("Cache dir: %s", out)
+ return out
+
+ @property
+ def prepared_tables_dir(self) -> Path:
+ return self.cache_dir / "flattened_tables"
+
+ def _ensure_prepared_tables(self) -> None:
+ root = Path(self.source_root)
+ if not root.is_dir():
+ raise FileNotFoundError(f"MIMIC4 FHIR root not found: {root}")
+
+ expected = [self.prepared_tables_dir / FHIR_TABLE_FILE_NAMES[t] for t in FHIR_TABLES]
+ if all(p.is_file() for p in expected):
+ return
+ if self.prepared_tables_dir.exists():
+ shutil.rmtree(self.prepared_tables_dir)
+
+ try:
+ staging_root = self.create_tmpdir()
+ staging = staging_root / "flattened_fhir_tables"
+ staging.mkdir(parents=True, exist_ok=True)
+ stream_fhir_ndjson_to_flat_tables(root, self.glob_patterns, staging)
+
+ if self.max_patients is None:
+ shutil.move(str(staging), str(self.prepared_tables_dir))
+ return
+
+ filtered_root = self.create_tmpdir()
+ filtered = filtered_root / "flattened_fhir_tables_filtered"
+ patient_ids = _sorted_patient_ids_from_flat_tables(staging)
+ filter_flat_tables_by_patient_ids(staging, filtered, patient_ids[: self.max_patients])
+ shutil.move(str(filtered), str(self.prepared_tables_dir))
+ finally:
+ self.clean_tmpdir()
+
+ def _event_transform(self, output_dir: Path) -> None:
+ self._ensure_prepared_tables()
+ super()._event_transform(output_dir)
+
+ def load_table(self, table_name: str) -> dd.DataFrame:
+ """Load one flattened Parquet table into the standard event schema.
+
+ Deviations from BaseDataset.load_table (which reads CSV via _scan_csv_tsv_gz):
+ - Reads from pre-built Parquet under prepared_tables_dir.
+ - Timestamp parsing uses errors="coerce" + utc=True (FHIR ISO strings
+ include timezone suffix or are partial dates).
+ - Strips tz-aware timestamps to naive UTC for Dask compatibility.
+ - Drops rows with null patient_id before returning.
+ """
+ assert self.config is not None
+ if table_name not in self.config.tables:
+ raise ValueError(f"Table {table_name} not found in config")
+
+ table_cfg = self.config.tables[table_name]
+ path = self.prepared_tables_dir / table_cfg.file_path
+ if not path.exists():
+ raise FileNotFoundError(f"Flattened table not found: {path}")
+
+ logger.info("Scanning FHIR flattened table: %s from %s", table_name, path)
+ df: dd.DataFrame = dd.read_parquet(
+ str(path), split_row_groups=True, blocksize="64MB"
+ ).replace("", pd.NA)
+ df = df.rename(columns=str.lower)
+
+ preprocess_func = getattr(self, f"preprocess_{table_name}", None)
+ if preprocess_func is not None:
+ logger.info("Preprocessing FHIR table: %s with %s", table_name, preprocess_func.__name__)
+ df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr]
+
+ for join_cfg in table_cfg.join:
+ join_path = self.prepared_tables_dir / Path(join_cfg.file_path).name
+ if not join_path.exists():
+ raise FileNotFoundError(f"FHIR join table not found: {join_path}")
+ logger.info("Joining FHIR table %s with %s", table_name, join_path)
+ join_df: dd.DataFrame = dd.read_parquet(
+ str(join_path), split_row_groups=True, blocksize="64MB"
+ ).replace("", pd.NA)
+ join_df = join_df.rename(columns=str.lower)
+ join_key = join_cfg.on.lower()
+ cols = [c.lower() for c in join_cfg.columns]
+ df = df.merge(join_df[[join_key] + cols], on=join_key, how=join_cfg.how)
+
+ ts_col = table_cfg.timestamp
+ if ts_col:
+ ts = (
+ functools.reduce(operator.add, (df[c].astype("string") for c in ts_col))
+ if isinstance(ts_col, list)
+ else df[ts_col].astype("string")
+ )
+ ts = dd.to_datetime(ts, format=table_cfg.timestamp_format, errors="coerce", utc=True)
+ df = df.assign(timestamp=ts.map_partitions(_strip_tz_to_naive_ms))
+ else:
+ df = df.assign(timestamp=pd.NaT)
+
+ if table_cfg.patient_id:
+ df = df.assign(patient_id=df[table_cfg.patient_id].astype("string"))
+ else:
+ df = df.reset_index(drop=True)
+ df = df.assign(patient_id=df.index.astype("string"))
+
+ df = df.dropna(subset=["patient_id"])
+ df = df.assign(event_type=table_name)
+ rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in table_cfg.attributes}
+ df = df.rename(columns=rename_attr)
+ return df[["patient_id", "event_type", "timestamp"] + [rename_attr[a.lower()] for a in table_cfg.attributes]]
+
+ @property
+ def unique_patient_ids(self) -> List[str]:
+ if self._unique_patient_ids is None:
+ self._unique_patient_ids = (
+ self.global_event_df
+ .select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+ logger.info("Found %d unique patient IDs", len(self._unique_patient_ids))
+ return self._unique_patient_ids
+
+ def set_task(
+ self,
+ task: Any = None,
+ num_workers: Optional[int] = None,
+ input_processors: Optional[Any] = None,
+ output_processors: Optional[Any] = None,
+ ) -> Any:
+ self._main_guard(self.set_task.__name__)
+ if task is None:
+ raise ValueError("Pass a task instance, e.g. MPFClinicalPredictionTask(max_len=512).")
+
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ if isinstance(task, MPFClinicalPredictionTask):
+ worker_count = (
+ 1 if in_notebook() else (num_workers if num_workers is not None else self.num_workers)
+ )
+ warmup_pids = self._mpf_patient_ids_for_task(task)
+ patient_count = len(warmup_pids)
+ effective_workers = min(worker_count, patient_count) if patient_count else 1
+ ensure_special_tokens(self.vocab)
+ self._warm_mpf_vocabulary(task, warmup_pids)
+ task.frozen_vocab = effective_workers > 1
+ task.vocab = self.vocab
+ task._specials = ensure_special_tokens(self.vocab)
+
+ return super().set_task(task, num_workers, input_processors, output_processors)
+
+ def _mpf_patient_ids_for_task(self, task: Any) -> List[str]:
+ return (
+ task.pre_filter(self.global_event_df)
+ .select("patient_id")
+ .unique()
+ .collect(engine="streaming")
+ .to_series()
+ .sort()
+ .to_list()
+ )
+
+ def _warm_mpf_vocabulary(self, task: Any, patient_ids: List[str]) -> None:
+ clinical_cap = max(0, task.max_len - 2)
+ base = self.global_event_df
+ for batch in itertools.batched(patient_ids, 128):
+ patients = (
+ base.filter(pl.col("patient_id").is_in(batch))
+ .collect(engine="streaming")
+ .partition_by("patient_id", as_dict=True)
+ )
+ for patient_key, patient_df in patients.items():
+ warm_mpf_vocab_from_patient(
+ self.vocab, Patient(patient_id=patient_key[0], data_source=patient_df), clinical_cap
+ )
+
+ def gather_samples(self, task: Any) -> List[Dict[str, Any]]:
+ task.vocab = self.vocab
+ task._specials = None
+ task.frozen_vocab = False
+ samples: List[Dict[str, Any]] = []
+ for patient in self.iter_patients():
+ samples.extend(task(patient))
+ return samples
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 5233b1726..86486dab8 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -38,6 +38,7 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .ehrmamba import EHRMamba, MambaBlock
+from .ehrmamba_cehr import EHRMambaCEHR
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
from .text_embedding import TextEmbedding
diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py
new file mode 100644
index 000000000..7974a699e
--- /dev/null
+++ b/pyhealth/models/cehr_embeddings.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 Vector Institute / Odyssey authors
+#
+# Derived from Odyssey (https://github.com/VectorInstitute/odyssey):
+# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, TimeEmbeddingLayer, VisitEmbedding
+# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args.
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+
+class TimeEmbeddingLayer(nn.Module):
+ """Embedding layer for time features (sinusoidal)."""
+
+ def __init__(self, embedding_size: int, is_time_delta: bool = False):
+ super().__init__()
+ self.embedding_size = embedding_size
+ self.is_time_delta = is_time_delta
+ self.w = nn.Parameter(torch.empty(1, self.embedding_size))
+ self.phi = nn.Parameter(torch.empty(1, self.embedding_size))
+ nn.init.xavier_uniform_(self.w)
+ nn.init.xavier_uniform_(self.phi)
+
+ def forward(self, time_stamps: torch.Tensor) -> torch.Tensor:
+ if self.is_time_delta:
+ time_stamps = torch.cat(
+ (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]),
+ dim=-1,
+ )
+ time_stamps = time_stamps.float()
+ next_input = time_stamps.unsqueeze(-1) * self.w + self.phi
+ return torch.sin(next_input)
+
+
+class VisitEmbedding(nn.Module):
+ """Embedding layer for visit segments."""
+
+ def __init__(self, visit_order_size: int, embedding_size: int):
+ super().__init__()
+ self.embedding = nn.Embedding(visit_order_size, embedding_size)
+
+ def forward(self, visit_segments: torch.Tensor) -> torch.Tensor:
+ return self.embedding(visit_segments)
+
+
+class MambaEmbeddingsForCEHR(nn.Module):
+ """CEHR-style combined embeddings for Mamba (concept + type + time + age + visit)."""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ hidden_size: int,
+ pad_token_id: int = 0,
+ type_vocab_size: int = 9,
+ max_num_visits: int = 512,
+ time_embeddings_size: int = 32,
+ visit_order_size: int = 3,
+ layer_norm_eps: float = 1e-12,
+ hidden_dropout_prob: float = 0.1,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.pad_token_id = pad_token_id
+ self.type_vocab_size = type_vocab_size
+ self.max_num_visits = max_num_visits
+ self.word_embeddings = nn.Embedding(
+ vocab_size, hidden_size, padding_idx=pad_token_id
+ )
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+ self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size)
+ self.time_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=True
+ )
+ self.age_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=False
+ )
+ self.visit_segment_embeddings = VisitEmbedding(
+ visit_order_size=visit_order_size, embedding_size=hidden_size
+ )
+ self.scale_back_concat_layer = nn.Linear(
+ hidden_size + 2 * time_embeddings_size, hidden_size
+ )
+ self.tanh = nn.Tanh()
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ token_type_ids_batch: torch.Tensor,
+ time_stamps: torch.Tensor,
+ ages: torch.Tensor,
+ visit_orders: torch.Tensor,
+ visit_segments: torch.Tensor,
+ ) -> torch.Tensor:
+ inputs_embeds = self.word_embeddings(input_ids)
+ time_stamps_embeds = self.time_embeddings(time_stamps)
+ ages_embeds = self.age_embeddings(ages)
+ visit_segments_embeds = self.visit_segment_embeddings(visit_segments)
+ visit_order_embeds = self.visit_order_embeddings(visit_orders)
+ token_type_embeds = self.token_type_embeddings(token_type_ids_batch)
+ concat_in = torch.cat(
+ (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1
+ )
+ h = self.tanh(self.scale_back_concat_layer(concat_in))
+ embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds
+ embeddings = self.dropout(embeddings)
+ return self.LayerNorm(embeddings)
diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py
new file mode 100644
index 000000000..cd555629c
--- /dev/null
+++ b/pyhealth/models/ehrmamba_cehr.py
@@ -0,0 +1,117 @@
+"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from pyhealth.datasets import SampleDataset
+
+from .base_model import BaseModel
+from .cehr_embeddings import MambaEmbeddingsForCEHR
+from .ehrmamba import MambaBlock
+from .utils import get_rightmost_masked_timestep
+
+
+class EHRMambaCEHR(BaseModel):
+ """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline).
+
+ Args:
+ dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema.
+ vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``).
+ embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings).
+ num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers.
+ pad_token_id: Padding id for masking (default 0).
+ state_size: SSM state size per channel.
+ conv_kernel: Causal conv kernel in each block.
+ dropout: Dropout before classifier.
+ """
+
+ def __init__(
+ self,
+ dataset: SampleDataset,
+ vocab_size: int,
+ embedding_dim: int = 128,
+ num_layers: int = 2,
+ pad_token_id: int = 0,
+ state_size: int = 16,
+ conv_kernel: int = 4,
+ dropout: float = 0.1,
+ type_vocab_size: int = 16,
+ max_num_visits: int = 512,
+ time_embeddings_size: int = 32,
+ visit_segment_vocab: int = 3,
+ ):
+ super().__init__(dataset=dataset)
+ self.embedding_dim = embedding_dim
+ self.num_layers = num_layers
+ self.pad_token_id = pad_token_id
+ self.vocab_size = vocab_size
+
+ assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only"
+ self.label_key = self.label_keys[0]
+ self.mode = self.dataset.output_schema[self.label_key]
+
+ self.embeddings = MambaEmbeddingsForCEHR(
+ vocab_size=vocab_size,
+ hidden_size=embedding_dim,
+ pad_token_id=pad_token_id,
+ type_vocab_size=type_vocab_size,
+ max_num_visits=max_num_visits,
+ time_embeddings_size=time_embeddings_size,
+ visit_order_size=visit_segment_vocab,
+ )
+ self.blocks = nn.ModuleList(
+ [
+ MambaBlock(
+ d_model=embedding_dim,
+ state_size=state_size,
+ conv_kernel=conv_kernel,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.dropout = nn.Dropout(dropout)
+ out_dim = self.get_output_size()
+ self.fc = nn.Linear(embedding_dim, out_dim)
+ self._forecasting_head: Optional[nn.Module] = None
+
+ def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]:
+ """Optional next-token / forecasting head (extension point; not implemented)."""
+
+ return None
+
+ def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]:
+ concept_ids = kwargs["concept_ids"].to(self.device).long()
+ token_type_ids = kwargs["token_type_ids"].to(self.device).long()
+ time_stamps = kwargs["time_stamps"].to(self.device).float()
+ ages = kwargs["ages"].to(self.device).float()
+ visit_orders = kwargs["visit_orders"].to(self.device).long()
+ visit_segments = kwargs["visit_segments"].to(self.device).long()
+
+ x = self.embeddings(
+ input_ids=concept_ids,
+ token_type_ids_batch=token_type_ids,
+ time_stamps=time_stamps,
+ ages=ages,
+ visit_orders=visit_orders,
+ visit_segments=visit_segments,
+ )
+ mask = concept_ids != self.pad_token_id
+ for blk in self.blocks:
+ x = blk(x)
+ pooled = get_rightmost_masked_timestep(x, mask)
+ logits = self.fc(self.dropout(pooled))
+ y_true = kwargs[self.label_key].to(self.device).float()
+ if y_true.dim() == 1:
+ y_true = y_true.unsqueeze(-1)
+ loss = self.get_loss_function()(logits, y_true)
+ y_prob = self.prepare_y_prob(logits)
+ return {
+ "loss": loss,
+ "y_prob": y_prob,
+ "y_true": y_true,
+ "logit": logits,
+ }
diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py
index 67edc010e..45cd6608d 100644
--- a/pyhealth/models/utils.py
+++ b/pyhealth/models/utils.py
@@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask):
last_hidden_states = torch.gather(hidden_states, 1, last_visit)
last_hidden_state = last_hidden_states[:, 0, :]
return last_hidden_state
+
+
+def get_rightmost_masked_timestep(hidden_states, mask):
+ """Gather hidden state at the last True position in ``mask`` per row.
+
+ Unlike :func:`get_last_visit`, this does **not** assume valid tokens form a
+ contiguous prefix; it picks the maximum index where ``mask`` is True.
+ Use for MPF / CEHR layouts where padding can appear between boundary tokens.
+
+ Args:
+ hidden_states: ``[batch, seq_len, hidden_size]``.
+ mask: ``[batch, seq_len]`` bool.
+
+ Returns:
+ Tensor ``[batch, hidden_size]``.
+ """
+ if mask is None:
+ return hidden_states[:, -1, :]
+ batch, seq_len, hidden = hidden_states.shape
+ device = hidden_states.device
+ idx = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(
+ batch, -1
+ )
+ idx_m = torch.where(mask, idx, torch.full_like(idx, -1))
+ last_idx = idx_m.max(dim=1).values.clamp(min=0)
+ last_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden)
+ gathered = torch.gather(hidden_states, 1, last_idx)
+ return gathered[:, 0, :]
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 797988377..ffdf99560 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -66,3 +66,11 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
+
+
+def __getattr__(name: str):
+ if name == "MPFClinicalPredictionTask":
+ from .mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ return MPFClinicalPredictionTask
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py
new file mode 100644
index 000000000..34927d8ab
--- /dev/null
+++ b/pyhealth/tasks/mpf_clinical_prediction.py
@@ -0,0 +1,152 @@
+"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from pyhealth.data import Patient
+from pyhealth.datasets.fhir_cehr import (
+ ConceptVocab,
+ build_cehr_sequences,
+ ensure_special_tokens,
+ infer_mortality_label,
+)
+
+from .base_task import BaseTask
+
+
+def _pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return seq + [pad] * (max_len - len(seq))
+
+
+def _pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return seq + [pad] * (max_len - len(seq))
+
+
+def _left_pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return [pad] * (max_len - len(seq)) + seq
+
+
+def _left_pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return [pad] * (max_len - len(seq)) + seq
+
+
+class MPFClinicalPredictionTask(BaseTask):
+ """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens.
+
+ Works on :class:`~pyhealth.data.Patient` via the standard
+ ``global_event_df`` / :meth:`~pyhealth.datasets.MIMIC4FHIRDataset.set_task`
+ path. For :meth:`set_task`,
+ :class:`~pyhealth.datasets.MIMIC4FHIRDataset` reserves specials, warms concept
+ keys in the main process over the same patient cohort as
+ :meth:`~pyhealth.tasks.base_task.BaseTask.pre_filter` (including when LitData
+ skips ``_task_transform`` on cache hit), and sets :attr:`frozen_vocab` when
+ multiple workers run :meth:`~pyhealth.datasets.BaseDataset._task_transform` so
+ worker processes do not race on :class:`~pyhealth.datasets.mimic4_fhir.ConceptVocab`.
+
+ Attributes:
+ max_len: Truncated sequence length (must be >= 2 for boundary tokens).
+ use_mpf: If True, use ```` / ```` specials; else ```` / ````.
+ vocab: Shared concept vocabulary (usually the dataset's vocab).
+ frozen_vocab: If True, do not add new concept ids (post-warmup parallel path).
+ """
+
+ task_name: str = "MPFClinicalPredictionFHIR"
+ input_schema: Dict[str, Any] = {
+ "concept_ids": ("tensor", {"dtype": torch.long}),
+ "token_type_ids": ("tensor", {"dtype": torch.long}),
+ "time_stamps": "tensor",
+ "ages": "tensor",
+ "visit_orders": ("tensor", {"dtype": torch.long}),
+ "visit_segments": ("tensor", {"dtype": torch.long}),
+ }
+ output_schema: Dict[str, str] = {"label": "binary"}
+
+ def __init__(self, max_len: int = 512, use_mpf: bool = True) -> None:
+ if max_len < 2:
+ raise ValueError("max_len must be >= 2 for MPF boundary tokens")
+ self.max_len = max_len
+ self.use_mpf = use_mpf
+ self.vocab: Optional[ConceptVocab] = None
+ self._specials: Optional[Dict[str, int]] = None
+ self.frozen_vocab: bool = False
+
+ def _ensure_vocab(self) -> ConceptVocab:
+ if self.vocab is None:
+ self.vocab = ConceptVocab()
+ if self._specials is None:
+ self._specials = ensure_special_tokens(self.vocab)
+ return self.vocab
+
+ def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
+ """Build one labeled sample dict per patient.
+
+ Args:
+ patient: A tabular :class:`~pyhealth.data.Patient`.
+
+ Returns:
+ A one-element list with ``concept_ids``, tensor-ready feature lists, and
+ ``label`` (0/1). Boundary tokens are always included; when
+ ``max_len == 2`` the sequence is ````/```` and ```` only.
+ """
+ vocab = self._ensure_vocab()
+ pid = patient.patient_id
+ clinical_cap = max(0, self.max_len - 2)
+ (
+ concept_ids,
+ token_types,
+ time_stamps,
+ ages,
+ visit_orders,
+ visit_segments,
+ ) = build_cehr_sequences(
+ patient,
+ vocab,
+ clinical_cap,
+ grow_vocab=not self.frozen_vocab,
+ )
+
+ assert self._specials is not None
+ mor_id = self._specials[""] if self.use_mpf else self._specials[""]
+ reg_id = self._specials[""]
+ z0 = 0
+ zf = 0.0
+ concept_ids = [mor_id] + concept_ids + [reg_id]
+ token_types = [z0] + token_types + [z0]
+ time_stamps = [zf] + time_stamps + [zf]
+ ages = [zf] + ages + [zf]
+ visit_orders = [z0] + visit_orders + [z0]
+ visit_segments = [z0] + visit_segments + [z0]
+
+ ml = self.max_len
+ concept_ids = _left_pad_int(concept_ids, ml, vocab.pad_id)
+ token_types = _left_pad_int(token_types, ml, 0)
+ time_stamps = _left_pad_float(time_stamps, ml, 0.0)
+ ages = _left_pad_float(ages, ml, 0.0)
+ visit_orders = _left_pad_int(visit_orders, ml, 0)
+ visit_segments = _left_pad_int(visit_segments, ml, 0)
+
+ label = infer_mortality_label(patient)
+ return [
+ {
+ "patient_id": pid,
+ "visit_id": f"{pid}-0",
+ "concept_ids": concept_ids,
+ "token_type_ids": token_types,
+ "time_stamps": time_stamps,
+ "ages": ages,
+ "visit_orders": visit_orders,
+ "visit_segments": visit_segments,
+ "label": label,
+ }
+ ]
diff --git a/pyproject.toml b/pyproject.toml
index 934d4f1bb..58afeb1b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"more-itertools~=10.8.0",
"einops>=0.8.0",
"linear-attention-transformer>=0.19.1",
+ "orjson~=3.10",
]
license = "BSD-3-Clause"
license-files = ["LICENSE.md"]
diff --git a/tests/core/mimic4_fhir_ndjson_fixtures.py b/tests/core/mimic4_fhir_ndjson_fixtures.py
new file mode 100644
index 000000000..2f48e2ba1
--- /dev/null
+++ b/tests/core/mimic4_fhir_ndjson_fixtures.py
@@ -0,0 +1,30 @@
+"""NDJSON file bodies for :mod:`tests.core.test_mimic4_fhir_dataset` (disk-only ingest)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from pyhealth.datasets.fhir_ingest import (
+ synthetic_mpf_one_patient_ndjson_text,
+ synthetic_mpf_two_patient_ndjson_text,
+)
+
+
+def ndjson_one_patient_text() -> str:
+ return synthetic_mpf_one_patient_ndjson_text()
+
+
+def ndjson_two_class_text() -> str:
+ return synthetic_mpf_two_patient_ndjson_text()
+
+
+def write_two_class_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ path = directory / name
+ path.write_text(ndjson_two_class_text(), encoding="utf-8")
+ return path
+
+
+def write_one_patient_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ path = directory / name
+ path.write_text(ndjson_one_patient_text(), encoding="utf-8")
+ return path
diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py
new file mode 100644
index 000000000..6c0ac79e0
--- /dev/null
+++ b/tests/core/test_ehrmamba_cehr.py
@@ -0,0 +1,124 @@
+import unittest
+
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+
+def _tiny_samples(seq: int = 16) -> tuple:
+ from pyhealth.datasets.fhir_cehr import ConceptVocab, ensure_special_tokens
+
+ task = MPFClinicalPredictionTask(max_len=seq, use_mpf=True)
+ task.vocab = ConceptVocab()
+ sp = ensure_special_tokens(task.vocab)
+ mid = task.vocab.add_token("test|filler")
+ samples = []
+ for lab in (0, 1):
+ samples.append(
+ {
+ "patient_id": f"p{lab}",
+ "visit_id": f"v{lab}",
+ "concept_ids": [sp[""]] + [mid] * (seq - 2) + [sp[""]],
+ "token_type_ids": [0] * seq,
+ "time_stamps": [0.0] * seq,
+ "ages": [50.0] * seq,
+ "visit_orders": [0] * seq,
+ "visit_segments": [0] * seq,
+ "label": lab,
+ }
+ )
+ return samples, task
+
+
+class TestEHRMambaCEHR(unittest.TestCase):
+ def test_readout_pools_rightmost_non_pad(self) -> None:
+ """MPF padding between tokens must not make pooling pick a pad position."""
+
+ from pyhealth.models.utils import (
+ get_last_visit,
+ get_rightmost_masked_timestep,
+ )
+
+ h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]])
+ m = torch.tensor([[True, True, False, True]])
+ out = get_rightmost_masked_timestep(h, m)
+ self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0])))
+ wrong = get_last_visit(h, m)
+ self.assertFalse(torch.allclose(out[0], wrong[0]))
+
+ def test_end_to_end_fhir_pipeline(self) -> None:
+ import tempfile
+ from pathlib import Path
+
+ from pyhealth.datasets import MIMIC4FHIRDataset, create_sample_dataset
+ from pyhealth.datasets import get_dataloader
+
+ from tests.core.mimic4_fhir_ndjson_fixtures import write_two_class_ndjson
+
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
+ )
+ samples = ds.gather_samples(task)
+ sample_ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ dataset_name="fhir_test",
+ )
+ vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=64,
+ num_layers=1,
+ )
+ batch = next(
+ iter(get_dataloader(sample_ds, batch_size=2, shuffle=False))
+ )
+ out = model(**batch)
+ self.assertIn("loss", out)
+ out["loss"].backward()
+
+ def test_forward_backward(self) -> None:
+ samples, task = _tiny_samples()
+ ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ )
+ vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+ model = EHRMambaCEHR(
+ dataset=ds,
+ vocab_size=vocab_size,
+ embedding_dim=64,
+ num_layers=1,
+ state_size=8,
+ )
+ batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+ out = model(**batch)
+ self.assertEqual(out["logit"].shape[0], 2)
+ out["loss"].backward()
+
+ def test_eval_mode(self) -> None:
+ samples, task = _tiny_samples()
+ ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ )
+ vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+ model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1)
+ model.eval()
+ with torch.no_grad():
+ batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+ out = model(**batch)
+ self.assertIn("y_prob", out)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_mimic4_fhir_dataset.py b/tests/core/test_mimic4_fhir_dataset.py
new file mode 100644
index 000000000..1611b1979
--- /dev/null
+++ b/tests/core/test_mimic4_fhir_dataset.py
@@ -0,0 +1,700 @@
+import gzip
+import tempfile
+import unittest
+from pathlib import Path
+from typing import Dict, List
+
+import orjson
+import polars as pl
+
+from pyhealth.data import Patient
+from pyhealth.datasets import MIMIC4FHIRDataset
+from pyhealth.datasets.fhir_ingest import (
+ _flatten_resource_to_table_row,
+ synthetic_mpf_two_patient_ndjson_text,
+)
+from pyhealth.datasets.fhir_cehr import (
+ ConceptVocab,
+ build_cehr_sequences,
+ collect_cehr_timeline_events,
+ infer_mortality_label,
+)
+
+from tests.core.mimic4_fhir_ndjson_fixtures import (
+ ndjson_two_class_text,
+ write_one_patient_ndjson,
+ write_two_class_ndjson,
+)
+
+
+def _third_patient_loinc_resources() -> List[Dict[str, object]]:
+ return [
+ {
+ "resourceType": "Patient",
+ "id": "p-synth-3",
+ "birthDate": "1960-01-01",
+ },
+ {
+ "resourceType": "Encounter",
+ "id": "e3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "period": {"start": "2020-08-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation",
+ "id": "o3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "encounter": {"reference": "Encounter/e3"},
+ "effectiveDateTime": "2020-08-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]},
+ },
+ ]
+
+
+def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ lines = synthetic_mpf_two_patient_ndjson_text().strip().split("\n")
+ lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources())
+ path = directory / name
+ path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return path
+
+
+def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient:
+ return Patient(patient_id=patient_id, data_source=pl.DataFrame(rows))
+
+
+class TestDeceasedBooleanFlattening(unittest.TestCase):
+ def test_string_false_not_coerced_by_python_bool(self) -> None:
+ """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``."""
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-str-false",
+ "deceasedBoolean": "false",
+ }
+ )
+ self.assertIsNotNone(row)
+ _table, payload = row
+ self.assertEqual(payload.get("deceased_boolean"), "false")
+
+ def test_string_true_parsed(self) -> None:
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-str-true",
+ "deceasedBoolean": "true",
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertEqual(row[1].get("deceased_boolean"), "true")
+
+ def test_json_booleans_unchanged(self) -> None:
+ for raw, expected in ((True, "true"), (False, "false")):
+ with self.subTest(raw=raw):
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-bool",
+ "deceasedBoolean": raw,
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertEqual(row[1].get("deceased_boolean"), expected)
+
+ def test_unknown_deceased_type_stored_as_none(self) -> None:
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-garbage",
+ "deceasedBoolean": {"unexpected": "object"},
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertIsNone(row[1].get("deceased_boolean"))
+
+ def test_infer_mortality_respects_string_false_row(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "event_type": "patient",
+ "timestamp": "2020-01-01T00:00:00",
+ "patient/deceased_boolean": "false",
+ },
+ ],
+ )
+ self.assertEqual(infer_mortality_label(patient), 0)
+
+
+class TestMIMIC4FHIRDataset(unittest.TestCase):
+ def test_concept_vocab_from_json_empty_token_to_id(self) -> None:
+ v = ConceptVocab.from_json({"token_to_id": {}})
+ self.assertIn("", v.token_to_id)
+ self.assertIn("", v.token_to_id)
+ self.assertEqual(v._next_id, 2)
+
+ def test_concept_vocab_from_json_empty_respects_next_id(self) -> None:
+ v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50})
+ self.assertEqual(v._next_id, 50)
+
+ def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None:
+ from pyhealth.datasets.fhir_ingest import sorted_ndjson_files
+
+ with tempfile.TemporaryDirectory() as tmp:
+ root = Path(tmp)
+ (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8")
+ (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8")
+ (root / "notes.txt").write_text("z", encoding="utf-8")
+ wide = sorted_ndjson_files(root, "**/*.ndjson.gz")
+ narrow = sorted_ndjson_files(
+ root,
+ ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"],
+ )
+ self.assertEqual(len(wide), 2)
+ self.assertEqual(len(narrow), 1)
+ self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz")
+
+ def test_dataset_accepts_glob_patterns_kwarg(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_one_patient_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp
+ )
+ self.assertEqual(ds.glob_patterns, ["*.ndjson"])
+ _ = ds.global_event_df.collect(engine="streaming")
+
+ def test_dataset_rejects_both_glob_kwargs(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ with self.assertRaises(ValueError):
+ MIMIC4FHIRDataset(
+ root=tmp,
+ glob_pattern="*.ndjson",
+ glob_patterns=["*.ndjson"],
+ cache_dir=tmp,
+ )
+
+ def test_disk_fixture_resolves_events_per_patient(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_one_patient_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ sub = ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1").collect(
+ engine="streaming"
+ )
+ self.assertGreaterEqual(len(sub), 2)
+ self.assertIn("condition/concept_key", sub.columns)
+
+ def test_prepared_flat_tables_exist(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ _ = ds.global_event_df.collect(engine="streaming")
+ prepared = ds.prepared_tables_dir
+ self.assertTrue((prepared / "patient.parquet").is_file())
+ self.assertTrue((prepared / "encounter.parquet").is_file())
+ self.assertTrue((prepared / "condition.parquet").is_file())
+
+ def test_build_cehr_non_empty(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_one_patient_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ task = MPFClinicalPredictionTask(max_len=64, use_mpf=True)
+ ds.gather_samples(task)
+ self.assertIsInstance(ds.vocab, ConceptVocab)
+ self.assertGreater(ds.vocab.vocab_size, 2)
+
+ def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ task_kw = {"max_len": 64, "use_mpf": True}
+ ds1 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ ds1.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1)
+ warm_size = ds1.vocab.vocab_size
+ self.assertGreater(warm_size, 6)
+ ds2 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ ds2.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1)
+ self.assertEqual(ds2.vocab.vocab_size, warm_size)
+
+ def test_mortality_heuristic(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=64, use_mpf=False))
+ self.assertEqual({s["label"] for s in samples}, {0, 1})
+
+ def test_infer_deceased(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ dead = ds.get_patient("p-synth-2")
+ self.assertEqual(infer_mortality_label(dead), 1)
+
+ def test_disk_ndjson_gz_physionet_style(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ gz_path = Path(tmp) / "fixture.ndjson.gz"
+ with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
+ gz.write(ndjson_two_class_text())
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5)
+ self.assertGreaterEqual(len(ds.unique_patient_ids), 1)
+
+ def test_disk_ndjson_temp_dir(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", max_patients=5)
+ self.assertEqual(len(ds.unique_patient_ids), 2)
+ samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True))
+ self.assertGreaterEqual(len(samples), 1)
+ for sample in samples:
+ self.assertIn("concept_ids", sample)
+ self.assertIn("label", sample)
+
+ def test_global_event_df_schema_and_flattened_columns(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ df = ds.global_event_df.collect(engine="streaming")
+ self.assertGreater(len(df), 0)
+ self.assertIn("patient_id", df.columns)
+ self.assertIn("timestamp", df.columns)
+ self.assertIn("event_type", df.columns)
+ self.assertIn("condition/concept_key", df.columns)
+ self.assertIn("observation/concept_key", df.columns)
+ self.assertIn("patient/deceased_boolean", df.columns)
+
+ def test_set_task_parity_with_gather_samples_ndjson(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp), name="fx.ndjson")
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
+ )
+ ref = sorted(
+ ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)),
+ key=lambda s: s["patient_id"],
+ )
+ sample_ds = ds.set_task(
+ MPFClinicalPredictionTask(max_len=48, use_mpf=True), num_workers=1
+ )
+ got = sorted(
+ [sample_ds[i] for i in range(len(sample_ds))],
+ key=lambda s: s["patient_id"],
+ )
+ self.assertEqual(len(got), len(ref))
+ for expected, actual in zip(ref, got):
+ self.assertEqual(expected["label"], int(actual["label"]))
+ actual_ids = actual["concept_ids"]
+ if hasattr(actual_ids, "tolist"):
+ actual_ids = actual_ids.tolist()
+ self.assertEqual(expected["concept_ids"], actual_ids)
+
+ def test_gather_samples_resets_frozen_vocab_after_set_task(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp_a, tempfile.TemporaryDirectory() as tmp_b:
+ write_two_class_ndjson(Path(tmp_a), name="a.ndjson")
+ write_two_class_ndjson(Path(tmp_b), name="b.ndjson")
+ ds_a = MIMIC4FHIRDataset(
+ root=tmp_a, glob_pattern="*.ndjson", cache_dir=tmp_a, num_workers=1
+ )
+ ds_b = MIMIC4FHIRDataset(
+ root=tmp_b, glob_pattern="*.ndjson", cache_dir=tmp_b, num_workers=1
+ )
+ task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
+ ds_a.set_task(task, num_workers=1)
+ self.assertFalse(task.frozen_vocab)
+
+ ref = sorted(
+ ds_b.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)),
+ key=lambda s: s["patient_id"],
+ )
+ got = sorted(ds_b.gather_samples(task), key=lambda s: s["patient_id"])
+ self.assertEqual(len(got), len(ref))
+ for expected, actual in zip(ref, got):
+ self.assertEqual(expected["label"], actual["label"])
+ self.assertEqual(expected["concept_ids"], actual["concept_ids"])
+
+ def test_set_task_multi_worker_sets_frozen_vocab(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
+ )
+ task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
+ ds.set_task(task, num_workers=2)
+ self.assertTrue(task.frozen_vocab)
+
+ def test_mpf_pre_filter_vocab_warmup_excludes_dropped_patients(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ class TwoPatientMPFTask(MPFClinicalPredictionTask):
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"]))
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_plus_third_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
+ )
+ self.assertEqual(len(ds.unique_patient_ids), 3)
+ task = TwoPatientMPFTask(max_len=48, use_mpf=True)
+ ds.set_task(task, num_workers=1)
+ self.assertNotIn("http://loinc.org|999-9", ds.vocab.token_to_id)
+ self.assertIn("http://loinc.org|789-0", ds.vocab.token_to_id)
+
+ def test_mpf_pre_filter_patient_ids_drive_effective_workers(self) -> None:
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ class OnePatientMPFTask(MPFClinicalPredictionTask):
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(pl.col("patient_id") == "p-synth-1")
+
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
+ )
+ task = OnePatientMPFTask(max_len=48, use_mpf=True)
+ warmup_pids = ds._mpf_patient_ids_for_task(task)
+ self.assertEqual(warmup_pids, ["p-synth-1"])
+ effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1
+ self.assertEqual(effective_workers, 1)
+
+ def test_encounter_reference_requires_exact_id(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-02T10:00:00",
+ "encounter/encounter_id": "e10",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-07-02T11:00:00",
+ "condition/encounter_id": "e10",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+ tid = vocab["http://hl7.org/fhir/sid/icd-10-cm|I99"]
+ self.assertEqual(concept_ids.count(tid), 1)
+
+ def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "ea",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-01T10:00:00",
+ "encounter/encounter_id": "eb",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-15T12:00:00",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+ self.assertEqual(concept_ids.count(vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]), 1)
+
+ def test_cehr_sequence_shapes(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_one_patient_ndjson(Path(tmp))
+ ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+ patient = ds.get_patient("p-synth-1")
+ vocab = ConceptVocab()
+ concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments = (
+ build_cehr_sequences(patient, vocab, max_len=32)
+ )
+ n = len(concept_ids)
+ self.assertEqual(len(token_types), n)
+ self.assertEqual(len(time_stamps), n)
+ self.assertEqual(len(ages), n)
+ self.assertEqual(len(visit_orders), n)
+ self.assertEqual(len(visit_segments), n)
+ self.assertGreater(n, 0)
+
+ def test_build_cehr_max_len_zero_no_clinical_tokens(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ c, _, _, _, _, vs = build_cehr_sequences(patient, vocab, max_len=0)
+ self.assertEqual(c, [])
+ self.assertEqual(vs, [])
+
+ def test_visit_segments_alternate_by_visit_index(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e0",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e0",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-07-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ _, _, _, _, _, visit_segments = build_cehr_sequences(patient, vocab, max_len=64)
+ self.assertEqual(visit_segments, [0, 0, 1, 1])
+
+ def test_unlinked_visit_idx_matches_sequential_counter(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": None,
+ "encounter/encounter_id": "e_bad",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-03-01T10:00:00",
+ "encounter/encounter_id": "e_ok",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-03-05T11:00:00",
+ "condition/encounter_id": "e_ok",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-03-15T12:00:00",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ concept_ids, _, _, _, visit_orders, visit_segments = build_cehr_sequences(
+ patient, vocab, max_len=64
+ )
+ i10 = vocab["http://hl7.org/fhir/sid/icd-10-cm|I10"]
+ z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]
+ i_link = concept_ids.index(i10)
+ i_free = concept_ids.index(z00)
+ self.assertEqual(visit_orders[i_link], visit_orders[i_free])
+ self.assertEqual(visit_segments[i_link], visit_segments[i_free])
+
+ def test_medication_request_uses_medication_codeable_concept(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T11:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T12:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+ ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111"
+ kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222"
+ self.assertNotEqual(vocab[ka], vocab[kb])
+ self.assertEqual(c.count(vocab[ka]), 1)
+ self.assertEqual(c.count(vocab[kb]), 1)
+
+ def test_medication_request_medication_reference_token(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T11:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "MedicationRequest/reference|med-abc",
+ },
+ ],
+ )
+ vocab = ConceptVocab()
+ c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+ key = "MedicationRequest/reference|med-abc"
+ self.assertIn(vocab[key], c)
+ self.assertEqual(c.count(vocab[key]), 1)
+
+ def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "a|1",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "observation",
+ "timestamp": "2020-06-01T12:00:00",
+ "observation/encounter_id": "e1",
+ "observation/concept_key": "b|2",
+ },
+ ],
+ )
+ events = collect_cehr_timeline_events(patient)
+ self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_mpf_task.py b/tests/core/test_mpf_task.py
new file mode 100644
index 000000000..d130e2d83
--- /dev/null
+++ b/tests/core/test_mpf_task.py
@@ -0,0 +1,88 @@
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+from pyhealth.datasets import MIMIC4FHIRDataset
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+from tests.core.mimic4_fhir_ndjson_fixtures import write_two_class_ndjson
+
+
+class TestMPFClinicalPredictionTask(unittest.TestCase):
+ def _two_patient_ds(self) -> MIMIC4FHIRDataset:
+ tmp = tempfile.mkdtemp()
+ self.addCleanup(lambda p=tmp: shutil.rmtree(p, ignore_errors=True))
+ write_two_class_ndjson(Path(tmp))
+ return MIMIC4FHIRDataset(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
+ )
+
+ def test_max_len_validation(self) -> None:
+ with self.assertRaises(ValueError):
+ MPFClinicalPredictionTask(max_len=1, use_mpf=True)
+
+ def test_mpf_sets_boundary_tokens(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
+ ds = self._two_patient_ds()
+ samples = ds.gather_samples(task)
+ vocab = ds.vocab
+ self.assertGreater(len(samples), 0)
+ s0 = samples[0]
+ mor = vocab[""]
+ reg = vocab[""]
+ pad_id = vocab.pad_id
+ ids = s0["concept_ids"]
+ first = next(i for i, x in enumerate(ids) if x != pad_id)
+ last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id)
+ self.assertEqual(ids[first], mor)
+ self.assertEqual(ids[last_nz], reg)
+ self.assertEqual(ids[-1], reg)
+
+ def test_no_mpf_uses_cls_reg(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=False)
+ ds = self._two_patient_ds()
+ samples = ds.gather_samples(task)
+ vocab = ds.vocab
+ s0 = samples[0]
+ cls_id = vocab[""]
+ reg_id = vocab[""]
+ pad_id = vocab.pad_id
+ ids = s0["concept_ids"]
+ first = next(i for i, x in enumerate(ids) if x != pad_id)
+ last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id)
+ self.assertEqual(ids[first], cls_id)
+ self.assertEqual(ids[last_nz], reg_id)
+ self.assertEqual(ids[-1], reg_id)
+
+ def test_schema_keys(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=16, use_mpf=True)
+ ds = self._two_patient_ds()
+ samples = ds.gather_samples(task)
+ for k in task.input_schema:
+ self.assertIn(k, samples[0])
+ self.assertIn("label", samples[0])
+
+ def test_max_len_two_keeps_boundary_tokens(self) -> None:
+ """``clinical_cap=0`` must yield ``[, ]`` left-padded, not truncated."""
+
+ task = MPFClinicalPredictionTask(max_len=2, use_mpf=True)
+ ds = self._two_patient_ds()
+ samples = ds.gather_samples(task)
+ vocab = ds.vocab
+ mor = vocab[""]
+ reg = vocab[""]
+ pad_id = vocab.pad_id
+ for s in samples:
+ ids = s["concept_ids"]
+ first = next(i for i, x in enumerate(ids) if x != pad_id)
+ last_nz = next(
+ i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id
+ )
+ self.assertEqual(ids[first], mor)
+ self.assertEqual(ids[last_nz], reg)
+ self.assertEqual(ids[-1], reg)
+
+
+if __name__ == "__main__":
+ unittest.main()