diff --git a/.gitignore b/.gitignore index aaf66fc36..0e3af028d 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,7 @@ pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model data/physionet.org/ # VSCode settings -.vscode/ \ No newline at end of file +.vscode/ + +# Benchmark runtime cache +benchmarks/.runtime_cache/ diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..fd3a809d2 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,46 @@ +# Loader Benchmark + +This benchmark compares PyHealth's streaming sample loader against the legacy in-memory loader using synthetic patient data. It focuses on: + +- Peak Python heap allocation (MB) using `tracemalloc` (a proxy for RAM pressure; `tracemalloc` traces Python allocations, not process RSS) +- Wall-clock time (seconds) using `time.perf_counter` +- Throughput (patients/second) + +Why this matters: +- The streaming loader is designed to reduce memory pressure on larger datasets. +- The legacy in-memory loader can be faster on small datasets but may use more RAM. +- This benchmark provides a reproducible baseline for tradeoff decisions by dataset size. 
+ +## How to run + +```bash +python benchmarks/loader_benchmark.py +``` + +The script benchmarks three scales by default: +- `small`: 100 patients +- `medium`: 1,000 patients +- `large`: 5,000 patients + +Outputs: +- `benchmarks/results.csv` +- `benchmarks/benchmark_chart.png` + +## Sample terminal output + +```text + scale num_patients loader dataset_class status wall_time_sec peak_ram_mb throughput_patients_per_sec note + small 100 legacy_in_memory InMemorySampleDataset ok 0.0022 0.1253 45,313.8046 + small 100 streaming SampleDataset ok 7.7841 1.3903 12.8466 +medium 1000 legacy_in_memory InMemorySampleDataset ok 0.0155 0.8346 64,692.1166 +medium 1000 streaming SampleDataset ok 8.2625 1.3849 121.0291 + large 5000 legacy_in_memory InMemorySampleDataset ok 0.0774 4.0510 64,578.5552 + large 5000 streaming SampleDataset ok 11.8215 5.5277 422.9585 +``` + +If streaming mode is unavailable in the environment, streaming rows are marked as `skipped` with a note, and the script still completes successfully. + +## Key findings + +- Placeholder: summarize RAM and time differences after running in your environment. +- Placeholder: note when streaming becomes beneficial by scale. diff --git a/benchmarks/benchmark_chart.png b/benchmarks/benchmark_chart.png new file mode 100644 index 000000000..aadb49da3 Binary files /dev/null and b/benchmarks/benchmark_chart.png differ diff --git a/benchmarks/loader_benchmark.py b/benchmarks/loader_benchmark.py new file mode 100644 index 000000000..a5c0ad02f --- /dev/null +++ b/benchmarks/loader_benchmark.py @@ -0,0 +1,468 @@ +#!/usr/bin/env python3 +"""Benchmark PyHealth streaming and legacy sample loaders. 
+ +This script compares: +1) Streaming loader: ``SampleDataset`` (disk-backed) +2) Legacy loader: ``InMemorySampleDataset`` (in-memory) + +It benchmarks synthetic patient samples at: +- small: 100 patients +- medium: 1,000 patients +- large: 5,000 patients + +Metrics: +- Peak RAM via ``tracemalloc`` (MB) +- Wall-clock time via ``time.perf_counter`` (seconds) +- Throughput (patients/second) +""" + +from __future__ import annotations + +import argparse +import contextlib +import gc +import inspect +import os +import sys +import time +import tracemalloc +import types +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence + +# Keep runtime cache local/writable for sandboxed environments. +RUNTIME_CACHE_DIR = Path("benchmarks/.runtime_cache") +RUNTIME_CACHE_DIR.mkdir(parents=True, exist_ok=True) +RUNTIME_MPL_DIR = RUNTIME_CACHE_DIR / "matplotlib" +RUNTIME_MPL_DIR.mkdir(parents=True, exist_ok=True) +os.environ.setdefault("XDG_CACHE_HOME", str(RUNTIME_CACHE_DIR.resolve())) +os.environ.setdefault("MPLCONFIGDIR", str(RUNTIME_MPL_DIR.resolve())) +# PyHealth uses Path.home()/.cache directly; set HOME to local runtime cache. +RUNTIME_HOME = (RUNTIME_CACHE_DIR / "home").resolve() +RUNTIME_HOME.mkdir(parents=True, exist_ok=True) +os.environ["HOME"] = str(RUNTIME_HOME) + +# Some environments may not ship Python's optional _lzma module. +# TorchVision imports lzma at import-time; provide a minimal fallback so this +# benchmark can still run when lzma compression is not used. +try: + import lzma # noqa: F401 +except ModuleNotFoundError: + lzma_stub = types.ModuleType("lzma") + + def _missing_lzma(*_args, **_kwargs): + raise ModuleNotFoundError( + "lzma support is unavailable in this Python build." 
+ ) + + class _MissingLZMAFile: + def __init__(self, *_args, **_kwargs): + _missing_lzma() + + class _MissingLZMAError(Exception): + pass + + lzma_stub.open = _missing_lzma # type: ignore[attr-defined] + lzma_stub.LZMAFile = _MissingLZMAFile # type: ignore[attr-defined] + lzma_stub.LZMAError = _MissingLZMAError # type: ignore[attr-defined] + sys.modules["lzma"] = lzma_stub + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import pyhealth.datasets as datasets_module +from pyhealth.datasets import ( + BaseDataset, + InMemorySampleDataset, + SampleDataset, + create_sample_dataset, + get_dataloader, +) + + +SCALES: Sequence[tuple[str, int]] = ( + ("small", 100), + ("medium", 1_000), + ("large", 5_000), +) +INPUT_SCHEMA = {"feature": "raw"} +OUTPUT_SCHEMA = {"label": "raw"} +DEFAULT_BATCH_SIZE = 256 +DEFAULT_RESULTS_CSV = Path("benchmarks/results.csv") +DEFAULT_CHART_PATH = Path("benchmarks/benchmark_chart.png") + +LEGACY_LOADER = "legacy_in_memory" +STREAMING_LOADER = "streaming" + + +@dataclass +class BenchmarkRow: + scale: str + num_patients: int + loader: str + dataset_class: str + status: str + wall_time_sec: float | None + peak_ram_mb: float | None + throughput_patients_per_sec: float | None + note: str + + +def parse_sizes(raw: str) -> List[int]: + """Parse comma-separated patient counts.""" + parsed: List[int] = [] + for token in raw.split(","): + token = token.strip().replace("_", "") + if not token: + continue + parsed.append(int(token)) + if not parsed: + raise ValueError("No valid sizes provided.") + return parsed + + +def generate_samples(num_patients: int) -> List[Dict[str, Any]]: + """Generate synthetic sample data with one sample per patient.""" + return [ + { + "patient_id": f"p{i}", + "record_id": f"r{i}", + "feature": [i % 17, (i * 3) % 23, (i * 7) % 31], + "label": i % 2, + } + for i in range(num_patients) + ] + + +def count_batch_patients(batch: Dict[str, Any]) -> int: + 
"""Count patients in a collated batch.""" + if "patient_id" in batch: + return len(batch["patient_id"]) + first_value = next(iter(batch.values())) + return len(first_value) + + +def discover_streaming_supported_datasets() -> List[str]: + """Discover dataset classes that inherit BaseDataset.""" + supported: List[str] = [] + for name, obj in inspect.getmembers(datasets_module, inspect.isclass): + if obj in (BaseDataset, SampleDataset, InMemorySampleDataset): + continue + if issubclass(obj, BaseDataset): + supported.append(name) + return sorted(set(supported)) + + +def summarize_loader_apis() -> Dict[str, str]: + """Summarize loader signatures to document API differences.""" + create_sig = inspect.signature(create_sample_dataset) + streaming_sig = inspect.signature(SampleDataset.__init__) + legacy_sig = inspect.signature(InMemorySampleDataset.__init__) + return { + "create_sample_dataset": str(create_sig), + "streaming_loader": f"SampleDataset{streaming_sig}", + "legacy_loader": f"InMemorySampleDataset{legacy_sig}", + "supports_in_memory_flag": str("in_memory" in create_sig.parameters), + } + + +def print_codebase_exploration() -> None: + """Print codebase-derived info before running benchmarks.""" + supported = discover_streaming_supported_datasets() + api = summarize_loader_apis() + + print("\nCodebase Exploration") + print("====================") + print( + f"Streaming-capable dataset classes (BaseDataset subclasses): {len(supported)}" + ) + print(", ".join(supported)) + print("\nLoader API differences:") + print(f"- Factory helper: create_sample_dataset{api['create_sample_dataset']}") + print(f"- Streaming loader: {api['streaming_loader']}") + print(f"- Legacy loader: {api['legacy_loader']}") + print( + "- Mode parameter: " + f"create_sample_dataset(..., in_memory=) -> " + "False: streaming / True: legacy in-memory" + ) + + +def _benchmark_one( + scale: str, + num_patients: int, + samples: List[Dict[str, Any]], + loader_name: str, + in_memory: bool, + batch_size: 
int, +) -> BenchmarkRow: + dataset = None + dataloader = None + processed = 0 + status = "ok" + note = "" + dataset_class = "" + + gc.collect() + tracemalloc.start() + start = time.perf_counter() + + try: + with open(os.devnull, "w", encoding="utf-8") as sink: + with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink): + dataset = create_sample_dataset( + samples=samples, + input_schema=INPUT_SCHEMA, + output_schema=OUTPUT_SCHEMA, + dataset_name="loader_benchmark", + task_name=f"{scale}_{loader_name}", + in_memory=in_memory, + ) + dataset_class = dataset.__class__.__name__ + + dataloader = get_dataloader( + dataset=dataset, + batch_size=min(batch_size, max(1, num_patients)), + shuffle=False, + ) + for batch in dataloader: + processed += count_batch_patients(batch) + except Exception as exc: + status = "skipped" if loader_name == STREAMING_LOADER else "error" + note = f"{type(exc).__name__}: {exc}" + finally: + wall_time = time.perf_counter() - start + _, peak_bytes = tracemalloc.get_traced_memory() + tracemalloc.stop() + + if dataset is not None and hasattr(dataset, "close"): + dataset.close() + del dataloader + del dataset + gc.collect() + + if status != "ok": + return BenchmarkRow( + scale=scale, + num_patients=num_patients, + loader=loader_name, + dataset_class=dataset_class or "n/a", + status=status, + wall_time_sec=None, + peak_ram_mb=None, + throughput_patients_per_sec=None, + note=note, + ) + + wall_time = float(wall_time) + peak_ram_mb = peak_bytes / (1024**2) + throughput = processed / wall_time if wall_time > 0 else None + return BenchmarkRow( + scale=scale, + num_patients=num_patients, + loader=loader_name, + dataset_class=dataset_class, + status=status, + wall_time_sec=wall_time, + peak_ram_mb=peak_ram_mb, + throughput_patients_per_sec=throughput, + note=note, + ) + + +def run_benchmark(sizes: Iterable[int], batch_size: int) -> pd.DataFrame: + """Run benchmark for both loaders on each scale.""" + label_for_size = {size: label for label, 
size in SCALES} + records: List[BenchmarkRow] = [] + + streaming_available = True + streaming_skip_note = "" + + for size in sizes: + scale_label = label_for_size.get(size, f"custom_{size}") + samples = generate_samples(size) + + records.append( + _benchmark_one( + scale=scale_label, + num_patients=size, + samples=samples, + loader_name=LEGACY_LOADER, + in_memory=True, + batch_size=batch_size, + ) + ) + + if streaming_available: + streaming_row = _benchmark_one( + scale=scale_label, + num_patients=size, + samples=samples, + loader_name=STREAMING_LOADER, + in_memory=False, + batch_size=batch_size, + ) + records.append(streaming_row) + if streaming_row.status != "ok": + streaming_available = False + streaming_skip_note = ( + streaming_row.note + or "Streaming mode unavailable in current environment." + ) + else: + records.append( + BenchmarkRow( + scale=scale_label, + num_patients=size, + loader=STREAMING_LOADER, + dataset_class="n/a", + status="skipped", + wall_time_sec=None, + peak_ram_mb=None, + throughput_patients_per_sec=None, + note=streaming_skip_note, + ) + ) + + df = pd.DataFrame(asdict(row) for row in records) + df = df.sort_values(["num_patients", "loader"]).reset_index(drop=True) + return df + + +def format_results_table(df: pd.DataFrame) -> pd.DataFrame: + """Format values for terminal display.""" + display_df = df.copy() + for col in ["wall_time_sec", "peak_ram_mb", "throughput_patients_per_sec"]: + display_df[col] = display_df[col].map( + lambda x: "-" if pd.isna(x) else f"{x:,.4f}" + ) + return display_df + + +def _metric_values( + df: pd.DataFrame, scales: List[str], loader: str, metric: str +) -> List[float]: + values: List[float] = [] + for scale in scales: + row = df[ + (df["scale"] == scale) & (df["loader"] == loader) & (df["status"] == "ok") + ] + if row.empty: + values.append(np.nan) + else: + values.append(float(row.iloc[0][metric])) + return values + + +def plot_results(df: pd.DataFrame, output_path: Path) -> None: + """Save a bar chart 
comparing RAM and wall time.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + scales = [label for label, _ in SCALES if label in set(df["scale"].tolist())] + if not scales: + scales = sorted(df["scale"].unique().tolist()) + + metrics = [ + ("peak_ram_mb", "Peak RAM (MB)"), + ("wall_time_sec", "Wall Time (s)"), + ] + loaders = [ + (LEGACY_LOADER, "Legacy in-memory"), + (STREAMING_LOADER, "Streaming"), + ] + x = np.arange(len(scales)) + width = 0.35 + + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + for ax, (metric, title) in zip(axes, metrics): + for idx, (loader_key, loader_label) in enumerate(loaders): + offset = (idx - 0.5) * width + values = _metric_values(df, scales, loader_key, metric) + valid_points = [(i, v) for i, v in enumerate(values) if not np.isnan(v)] + if not valid_points: + continue + x_positions = [x[i] + offset for i, _ in valid_points] + y_values = [v for _, v in valid_points] + ax.bar(x_positions, y_values, width=width, label=loader_label) + + ax.set_title(title) + ax.set_xticks(x) + ax.set_xticklabels(scales) + ax.set_xlabel("Scale") + ax.grid(axis="y", alpha=0.3) + + axes[0].set_ylabel("Value") + handles, labels = axes[0].get_legend_handles_labels() + if handles: + fig.legend(handles, labels, loc="upper center", ncol=2, frameon=False) + + streaming_skipped = df[ + (df["loader"] == STREAMING_LOADER) & (df["status"] != "ok") + ] + if not streaming_skipped.empty: + note = streaming_skipped.iloc[0]["note"] + fig.text( + 0.01, + 0.01, + f"Note: streaming benchmark skipped in this environment. {note}", + fontsize=9, + ) + + fig.tight_layout(rect=[0, 0.04, 1, 0.92]) + fig.savefig(output_path, dpi=200) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark PyHealth streaming vs legacy in-memory loaders." 
+ ) + parser.add_argument( + "--sizes", + type=str, + default="100,1000,5000", + help="Comma-separated patient counts (default: 100,1000,5000).", + ) + parser.add_argument( + "--batch-size", + type=int, + default=DEFAULT_BATCH_SIZE, + help=f"DataLoader batch size (default: {DEFAULT_BATCH_SIZE}).", + ) + parser.add_argument( + "--csv-out", + type=Path, + default=DEFAULT_RESULTS_CSV, + help=f"CSV output path (default: {DEFAULT_RESULTS_CSV}).", + ) + parser.add_argument( + "--chart-out", + type=Path, + default=DEFAULT_CHART_PATH, + help=f"Chart output path (default: {DEFAULT_CHART_PATH}).", + ) + args = parser.parse_args() + + sizes = parse_sizes(args.sizes) + print_codebase_exploration() + + results_df = run_benchmark(sizes=sizes, batch_size=args.batch_size) + args.csv_out.parent.mkdir(parents=True, exist_ok=True) + results_df.to_csv(args.csv_out, index=False) + + plot_results(results_df, args.chart_out) + + print("\nBenchmark Results") + print("=================") + print(format_results_table(results_df).to_string(index=False)) + print(f"\nSaved CSV: {args.csv_out}") + print(f"Saved chart: {args.chart_out}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/loader_benchmark_comparison.png b/benchmarks/loader_benchmark_comparison.png new file mode 100644 index 000000000..c10f0b789 Binary files /dev/null and b/benchmarks/loader_benchmark_comparison.png differ diff --git a/benchmarks/loader_benchmark_results.csv b/benchmarks/loader_benchmark_results.csv new file mode 100644 index 000000000..fea8fcd29 --- /dev/null +++ b/benchmarks/loader_benchmark_results.csv @@ -0,0 +1,7 @@ +loader,num_patients,wall_time_sec,peak_ram_mb,throughput_patients_per_sec,processed_patients,batch_size +in_memory,1000,0.015645167004549876,0.8375463485717773,63917.502428013926,1000,256 +streaming,1000,7.462171249993844,2.077880859375,134.00925367409988,1000,256 +in_memory,10000,0.15010529098799452,8.080141067504883,66619.90349693804,10000,256 
+streaming,10000,17.429230958980042,9.584845542907715,573.7487800543323,10000,256 +in_memory,100000,1.6128525000240188,84.10364246368408,62001.94996040294,100000,256 +streaming,100000,157.4611568749824,101.28286266326904,635.0772595897782,100000,256 diff --git a/benchmarks/results.csv b/benchmarks/results.csv new file mode 100644 index 000000000..f4516c9e7 --- /dev/null +++ b/benchmarks/results.csv @@ -0,0 +1,7 @@ +scale,num_patients,loader,dataset_class,status,wall_time_sec,peak_ram_mb,throughput_patients_per_sec,note +small,100,legacy_in_memory,InMemorySampleDataset,ok,0.00425804202677682,0.18910980224609375,23484.972522851374, +small,100,streaming,SampleDataset,ok,7.591603999986546,1.4529895782470703,13.17244682417276, +medium,1000,legacy_in_memory,InMemorySampleDataset,ok,0.015434042026754469,0.8984231948852539,64791.841195360794, +medium,1000,streaming,SampleDataset,ok,7.6678521250141785,1.4480886459350586,130.41461724826246, +large,5000,legacy_in_memory,InMemorySampleDataset,ok,0.08102108401362784,4.114830017089844,61712.33156000473, +large,5000,streaming,SampleDataset,ok,11.755720708024455,5.590873718261719,425.32483751396035, diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index effb47133..97e8d78e5 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -61,7 +61,12 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset +from .sample_dataset import ( + InMemorySampleDataset, + SampleBuilder, + SampleDataset, + create_sample_dataset, +) from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset diff --git a/pyhealth/metrics/multilabel.py b/pyhealth/metrics/multilabel.py index 03efbcfab..24ba7e7d2 100644 --- 
a/pyhealth/metrics/multilabel.py +++ b/pyhealth/metrics/multilabel.py @@ -3,7 +3,6 @@ import numpy as np import sklearn.metrics as sklearn_metrics -from pyhealth.medcode import ATC import pyhealth.metrics.calibration as calib from pyhealth.metrics import ddi_rate_score from pyhealth import BASE_CACHE_PATH as CACHE_PATH @@ -68,7 +67,7 @@ def multilabel_metrics_fn( y_true: True target values of shape (n_samples, n_labels). y_prob: Predicted probabilities of shape (n_samples, n_labels). metrics: List of metrics to compute. Default is ["pr_auc_samples"]. - threshold: Threshold to binarize the predicted probabilities. Default is 0.5. + threshold: Threshold to binarize the predicted probabilities. Default is 0.5. Returns: Dictionary of metrics whose keys are the metric names and values are @@ -207,8 +206,12 @@ def multilabel_metrics_fn( output["hamming_loss"] = hamming_loss elif metric == "ddi": ddi_adj = np.load(os.path.join(CACHE_PATH, 'ddi_adj.npy')) - y_pred = [np.where(item)[0] for item in y_pred] - output["ddi_score"] = ddi_rate_score(y_pred, ddi_adj) + pred_labels = [np.where(item)[0] for item in y_pred] + ddi_score = ddi_rate_score(pred_labels, ddi_adj) + # Keep "ddi" aligned with the requested metric name while preserving + # the historical key for backward compatibility. 
+ output["ddi"] = ddi_score + output["ddi_score"] = ddi_score elif metric in {"cwECE", "cwECE_adapt"}: output[metric] = calib.ece_classwise( y_prob, diff --git a/pyhealth/metrics/ranking.py b/pyhealth/metrics/ranking.py index b19f5107d..551acfd4f 100644 --- a/pyhealth/metrics/ranking.py +++ b/pyhealth/metrics/ranking.py @@ -1,4 +1,5 @@ from typing import List, Dict +from numbers import Integral def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]], @@ -31,11 +32,24 @@ def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]], >>> ranking_metrics_fn(qrels, results, k_values) {'NDCG@1': 0.5, 'MAP@1': 0.25, 'Recall@1': 0.25, 'P@1': 0.5, 'NDCG@2': 0.5, 'MAP@2': 0.375, 'Recall@2': 0.5, 'P@2': 0.5} """ + if not qrels: + raise ValueError("qrels must not be empty.") + if not results: + raise ValueError("results must not be empty.") + if not k_values: + raise ValueError("k_values must not be empty.") + if any(not isinstance(k, Integral) or int(k) <= 0 for k in k_values): + raise ValueError("k_values must contain only positive integers.") + + k_values = [int(k) for k in k_values] + try: import pytrec_eval - except: - raise ImportError("pytrec_eval is not installed. Please install it manually by running \ - 'pip install pytrec_eval'.") + except ModuleNotFoundError as exc: + raise ImportError( + "pytrec_eval is not installed. Please install it manually with " + "'pip install pytrec_eval'." + ) from exc ret = {} for k in k_values: @@ -52,6 +66,13 @@ def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]], {map_string, ndcg_string, recall_string, precision_string}) scores = evaluator.evaluate(results) + if not scores: + raise ValueError( + "No ranking scores were produced. Ensure results contain query ids " + "present in qrels." 
+ ) + + num_queries = len(scores) for query_id in scores.keys(): for k in k_values: @@ -61,10 +82,10 @@ def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]], ret[f"P@{k}"] += scores[query_id]["P_" + str(k)] for k in k_values: - ret[f"NDCG@{k}"] = round(ret[f"NDCG@{k}"] / len(scores), 5) - ret[f"MAP@{k}"] = round(ret[f"MAP@{k}"] / len(scores), 5) - ret[f"Recall@{k}"] = round(ret[f"Recall@{k}"] / len(scores), 5) - ret[f"P@{k}"] = round(ret[f"P@{k}"] / len(scores), 5) + ret[f"NDCG@{k}"] = round(ret[f"NDCG@{k}"] / num_queries, 5) + ret[f"MAP@{k}"] = round(ret[f"MAP@{k}"] / num_queries, 5) + ret[f"Recall@{k}"] = round(ret[f"Recall@{k}"] / num_queries, 5) + ret[f"P@{k}"] = round(ret[f"P@{k}"] / num_queries, 5) return ret diff --git a/pyhealth/metrics/regression.py b/pyhealth/metrics/regression.py index 37e454c70..24e590ec8 100644 --- a/pyhealth/metrics/regression.py +++ b/pyhealth/metrics/regression.py @@ -16,7 +16,7 @@ def regression_metrics_fn( - kl_divergence: KL divergence - mse: mean squared error - mae: mean absolute error - If no metrics are specified, kd_div, mse, mae are computed by default. + If no metrics are specified, kl_divergence, mse, and mae are computed by default. This function calls sklearn.metrics functions to compute the metrics. For more information on the metrics, please refer to the documentation of the @@ -32,11 +32,11 @@ def regression_metrics_fn( the metric values. 
Examples: - >>> from pyhealth.metrics import binary_metrics_fn - >>> y_true = np.array([0, 0, 1, 1]) - >>> y_prob = np.array([0.1, 0.4, 0.35, 0.8]) - >>> binary_metrics_fn(y_true, y_prob, metrics=["accuracy"]) - {'accuracy': 0.75} + >>> from pyhealth.metrics import regression_metrics_fn + >>> x = np.array([1.0, 2.0, 3.0]) + >>> x_rec = np.array([1.2, 1.8, 2.9]) + >>> regression_metrics_fn(x, x_rec, metrics=["mse"]) + {'mse': 0.03} """ if metrics is None: metrics = ["kl_divergence", "mse", "mae"] @@ -50,11 +50,11 @@ def regression_metrics_fn( output = {} for metric in metrics: if metric == "kl_divergence": - x[x < 1e-6] = 1e-6 - x_rec[x_rec < 1e-6] = 1e-6 - x = x / np.sum(x) - x_rec = x_rec / np.sum(x_rec) - kl_divergence = np.sum(x_rec * np.log(x_rec / x)) + x_safe = np.maximum(x, 1e-6) + x_rec_safe = np.maximum(x_rec, 1e-6) + x_dist = x_safe / np.sum(x_safe) + x_rec_dist = x_rec_safe / np.sum(x_rec_safe) + kl_divergence = np.sum(x_rec_dist * np.log(x_rec_dist / x_dist)) output["kl_divergence"] = kl_divergence elif metric == "mse": mse = sklearn_metrics.mean_squared_error(x, x_rec) diff --git a/tests/core/test_metrics_quality.py b/tests/core/test_metrics_quality.py new file mode 100644 index 000000000..7dd54a867 --- /dev/null +++ b/tests/core/test_metrics_quality.py @@ -0,0 +1,131 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path +from unittest.mock import patch + +import numpy as np + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def _load_module( + module_name: str, relative_path: str, stubs: dict[str, object] | None = None +): + module_path = REPO_ROOT / relative_path + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load module spec for {module_path}") + module = importlib.util.module_from_spec(spec) + with patch.dict(sys.modules, stubs or {}, clear=False): + spec.loader.exec_module(module) + return module + + 
+def _load_multilabel_module(ddi_rate_score_fn): + fake_pyhealth = types.ModuleType("pyhealth") + fake_pyhealth.BASE_CACHE_PATH = "/tmp" + + fake_metrics = types.ModuleType("pyhealth.metrics") + fake_metrics.__path__ = [] + fake_metrics.ddi_rate_score = ddi_rate_score_fn + + fake_calibration = types.ModuleType("pyhealth.metrics.calibration") + fake_calibration.ece_classwise = lambda *args, **kwargs: 0.0 + fake_metrics.calibration = fake_calibration + + stubs = { + "pyhealth": fake_pyhealth, + "pyhealth.metrics": fake_metrics, + "pyhealth.metrics.calibration": fake_calibration, + } + return _load_module( + "multilabel_metrics_under_test", "pyhealth/metrics/multilabel.py", stubs + ) + + +class TestRegressionMetricsQuality(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.regression = _load_module( + "regression_metrics_under_test", "pyhealth/metrics/regression.py" + ) + + def test_kl_divergence_does_not_change_mse(self): + x = np.array([0.2, 0.3, 0.5], dtype=float) + x_rec = np.array([0.1, 0.7, 0.2], dtype=float) + + mse_only = self.regression.regression_metrics_fn(x, x_rec, metrics=["mse"])[ + "mse" + ] + mse_with_kl = self.regression.regression_metrics_fn( + x, x_rec, metrics=["kl_divergence", "mse"] + )["mse"] + + self.assertAlmostEqual(mse_with_kl, mse_only, places=12) + + +class TestMultilabelMetricsQuality(unittest.TestCase): + def test_ddi_metric_does_not_break_followup_metrics(self): + def fake_ddi_rate_score(pred_labels, ddi_adj): + return 0.125 + + multilabel = _load_multilabel_module(fake_ddi_rate_score) + y_true = np.array([[1, 0, 1], [0, 1, 0]]) + y_prob = np.array([[0.8, 0.4, 0.9], [0.3, 0.7, 0.2]]) + + with patch.object(multilabel.np, "load", return_value=np.zeros((3, 3))): + scores = multilabel.multilabel_metrics_fn( + y_true, + y_prob, + metrics=["ddi", "f1_micro"], + threshold=0.5, + ) + + self.assertIn("ddi", scores) + self.assertIn("ddi_score", scores) + self.assertIn("f1_micro", scores) + self.assertEqual(scores["ddi"], 
scores["ddi_score"]) + + +class TestRankingMetricsQuality(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.ranking = _load_module( + "ranking_metrics_under_test", "pyhealth/metrics/ranking.py" + ) + + def test_invalid_k_values_raise_value_error(self): + with self.assertRaisesRegex(ValueError, "k_values"): + self.ranking.ranking_metrics_fn( + {"q1": {"d1": 1}}, + {"q1": {"d1": 1.0}}, + [0], + ) + + def test_empty_scores_raise_clear_error(self): + fake_pytrec_eval = types.ModuleType("pytrec_eval") + + class FakeEvaluator: + def __init__(self, qrels, metrics): + self.qrels = qrels + self.metrics = metrics + + def evaluate(self, results): + return {} + + fake_pytrec_eval.RelevanceEvaluator = FakeEvaluator + + with patch.dict(sys.modules, {"pytrec_eval": fake_pytrec_eval}, clear=False): + with self.assertRaisesRegex(ValueError, "No ranking scores were produced"): + self.ranking.ranking_metrics_fn( + {"q1": {"d1": 1}}, + {"q1": {"d1": 1.0}}, + [1], + ) + + +if __name__ == "__main__": + unittest.main()