diff --git a/src/config.py b/src/config.py index ffc4fb89..04dd0fc9 100644 --- a/src/config.py +++ b/src/config.py @@ -54,6 +54,11 @@ def load_routing_configuration(file: Path = _config_file) -> TomlTable: return typing.cast("TomlTable", _load_configuration(file)["routing"]) +@functools.cache +def load_run_configuration(file: Path = _config_file) -> TomlTable: + return typing.cast("TomlTable", _load_configuration(file).get("run", {})) + + @functools.cache def load_database_configuration(file: Path = _config_file) -> TomlTable: configuration = _load_configuration(file) diff --git a/src/config.toml b/src/config.toml index 384067d7..b56a3e0e 100644 --- a/src/config.toml +++ b/src/config.toml @@ -37,3 +37,6 @@ database="openml" [routing] minio_url="http://minio:9000/" server_url="http://php-api:80/" + +[run] +evaluation_engine_ids = [1] diff --git a/src/database/flows.py b/src/database/flows.py index 79bb6e5b..ec04f6a6 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -73,7 +73,7 @@ async def get(id_: int, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( text( """ - SELECT *, uploadDate as upload_date + SELECT *, uploadDate as upload_date, fullName AS full_name FROM implementation WHERE id = :flow_id """, diff --git a/src/database/runs.py b/src/database/runs.py index acf7a532..be16b5e4 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Row, text +from sqlalchemy import Row, bindparam, text from sqlalchemy.ext.asyncio import AsyncConnection @@ -22,6 +22,121 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool: return bool(row.one_or_none()) +async def get(run_id: int, expdb: AsyncConnection) -> Row | None: + """Fetch the core run row from the `run` table. + + Returns the row if found, or None if no run with `run_id` exists. + The `error_message` column is NULL when the run completed without errors. + """ + row = await expdb.execute( + text( + """ + SELECT `rid`, `uploader`, `setup`, `task_id`, `error_message` + FROM `run` + WHERE `rid` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return row.one_or_none() + + +async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]: + """Fetch all tags associated with a run from the `run_tag` table. + + The `id` column in `run_tag` refers to the run ID + """ + rows = await expdb.execute( + text( + """ + SELECT `tag` + FROM `run_tag` + WHERE `id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return [row.tag for row in rows.all()] + + +async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: + """Fetch the dataset(s) used as input for a run, with name and url. + + Joins `input_data` with `dataset` to include the dataset name and ARFF URL. + """ + rows = await expdb.execute( + text( + """ + SELECT `id`.`data` AS `did`, `d`.`name`, `d`.`url` + FROM `input_data` `id` + JOIN `dataset` `d` ON `id`.`data` = `d`.`did` + WHERE `id`.`run` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return cast("list[Row]", rows.all()) + + +async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: + """Fetch output files attached to a run from the `runfile` table. + + Typical entries include the description XML and predictions ARFF. + The `field` column holds the file label (e.g. "description", "predictions"). 
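+
+    Illustrative result shape for the run-24 test fixture (file IDs are the
+    fixture's values, shown here only as an example):
+
+        [Row(file_id=182, field="description"), Row(file_id=183, field="predictions")]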
+ """ + rows = await expdb.execute( + text( + """ + SELECT `file_id`, `field` + FROM `runfile` + WHERE `source` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return cast("list[Row]", rows.all()) + + +async def get_evaluations( + run_id: int, + expdb: AsyncConnection, + *, + evaluation_engine_ids: list[int], +) -> list[Row]: + """Fetch evaluation metric results for a run. + + Joins `evaluation` with `math_function` to resolve the metric name + (the `evaluation` table stores only a `function_id`, not the name directly). + + Filters by `evaluation_engine_id IN (...)`. The list is configurable + via `config.toml [run] evaluation_engine_ids`. + Dynamic named parameters are used for aiomysql compatibility. + """ + if not evaluation_engine_ids: + return [] + + query = text( + """ + SELECT `m`.`name`, `e`.`value`, `e`.`array_data`, NULL as `repeat`, NULL as `fold` + FROM `evaluation` `e` + JOIN `math_function` `m` ON `e`.`function_id` = `m`.`id` + WHERE `e`.`source` = :run_id + AND `e`.`evaluation_engine_id` IN :engine_ids + UNION ALL + SELECT `m`.`name`, `ef`.`value`, `ef`.`array_data`, `ef`.`repeat`, `ef`.`fold` + FROM `evaluation_fold` `ef` + JOIN `math_function` `m` ON `ef`.`function_id` = `m`.`id` + WHERE `ef`.`source` = :run_id + AND `ef`.`evaluation_engine_id` IN :engine_ids + """, + ).bindparams(bindparam("engine_ids", expanding=True)) + rows = await expdb.execute( + query, + parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids}, + ) + return cast("list[Row]", rows.all()) + + async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: """Get trace rows for a run from the trace table.""" rows = await expdb.execute( diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d26..a14a9ff2 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -48,6 +48,49 @@ async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None return row.one_or_none() +async def get_task_type_name(task_id: int, expdb: AsyncConnection) -> str | None: + """Fetch the human-readable task type name for the task associated with a run. + + Joins `task` and `task_type` on `ttid` to resolve the name + (e.g. "Supervised Classification"). + """ + row = await expdb.execute( + text( + """ + SELECT `tt`.`name` + FROM `task` `t` + JOIN `task_type` `tt` ON `t`.`ttid` = `tt`.`ttid` + WHERE `t`.`task_id` = :task_id + """, + ), + parameters={"task_id": task_id}, + ) + result = row.one_or_none() + return result.name if result else None + + +async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> str | None: + """Fetch the evaluation measure configured for a task, if any. + + Queries `task_inputs` for the row where `input = 'evaluation_measures'`. + Returns None (not an empty string) when no such row exists, so callers + can treat a falsy result uniformly. 
+ """ + row = await expdb.execute( + text( + """ + SELECT `value` + FROM `task_inputs` + WHERE `task_id` = :task_id + AND `input` = 'evaluation_measures' + """, + ), + parameters={"task_id": task_id}, + ) + result = row.one_or_none() + return result.value if result else None + + async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]: rows = await expdb.execute( text( diff --git a/src/database/users.py b/src/database/users.py index 0b09fb0d..c2be9143 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -26,19 +26,45 @@ class UserGroup(IntEnum): READ_ONLY = (3,) -async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> int | None: - row = await connection.execute( - text( - """ - SELECT * - FROM users - WHERE session_hash = :api_key - """, - ), - parameters={"api_key": api_key}, +async def get_user( + *, + connection: AsyncConnection, + api_key: APIKey | None = None, + user_id: int | None = None, +) -> "User | None": + """Fetch the full user by either api_key or user_id.""" + if (api_key is None) == (user_id is None): + msg = "Exactly one of api_key or user_id must be provided." + raise ValueError(msg) + + if api_key is not None: + query = """ + SELECT id, first_name, last_name + FROM users + WHERE session_hash = :api_key + LIMIT 1 + """ + else: + query = """ + SELECT id, first_name, last_name + FROM users + WHERE id = :user_id + LIMIT 1 + """ + + result = await connection.execute( + text(query), + parameters={"api_key": api_key, "user_id": user_id}, ) - user = row.one_or_none() - return user.id if user else None + row = result.one_or_none() + if row: + return User( + user_id=row.id, + first_name=row.first_name, + last_name=row.last_name, + _database=connection, + ) + return None async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[int]: @@ -60,12 +86,25 @@ async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> l class User: user_id: int _database: AsyncConnection + first_name: str = "" + last_name: str = "" _groups: list[UserGroup] | None = None + @property + def full_name(self) -> str: + """Return the combined first and last name.""" + return " ".join(part for part in [self.first_name, self.last_name] if part) + @classmethod async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: - if (user_id := await get_user_id_for(api_key=api_key, connection=user_db)) is not None: - return cls(user_id, _database=user_db) + user = await get_user(api_key=api_key, connection=user_db) + if user is not None: + return cls( + user_id=user.user_id, + first_name=user.first_name, + last_name=user.last_name, + _database=user_db, + ) return None async def get_groups(self) -> list[UserGroup]: diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 37a7cecf..340c8d19 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,14 +1,33 @@ """Endpoints for run-related data.""" -from typing import Annotated +import asyncio +from dataclasses import dataclass +from typing import TYPE_CHECKING, Annotated, Any, cast from fastapi import APIRouter, Depends + +if TYPE_CHECKING: + from sqlalchemy import Row from sqlalchemy.ext.asyncio import AsyncConnection +import config +import database.flows import database.runs +import database.setups +import database.tasks +import database.users from core.errors import RunNotFoundError, RunTraceNotFoundError -from routers.dependencies import expdb_connection -from schemas.runs import RunTrace, 
TraceIteration +from routers.dependencies import expdb_connection, userdb_connection +from schemas.runs import ( + EvaluationScore, + InputDataset, + OutputData, + OutputFile, + ParameterSetting, + Run, + RunTrace, + TraceIteration, +) router = APIRouter(prefix="/run", tags=["run"]) @@ -42,3 +61,159 @@ async def get_run_trace( for row in trace_rows ], ) + + +@dataclass +class RunContext: + """Helper context to store concurrently fetched run dependencies.""" + + uploader_name: str | None + tags: list[str] + input_data_rows: list["Row"] + output_file_rows: list["Row"] + evaluation_rows: list["Row"] + task_type: str | None + task_evaluation_measure: str | None + setup: "Row | None" + parameter_rows: list["Row"] + + +async def _load_run_context( + run: "Row", + run_id: int, + expdb: AsyncConnection, + userdb: AsyncConnection, + engine_ids: list[int], +) -> RunContext: + ( + uploader_user, + tags, + input_data_rows, + output_file_rows, + evaluation_rows, + task_type, + task_evaluation_measure, + setup, + parameter_rows, + ) = cast( + "tuple[Any, list[str], list[Row], list[Row], list[Row], str | None, str |" + "None, Row | None, list[Row]]", + await asyncio.gather( + database.users.get_user(user_id=run.uploader, connection=userdb), + database.runs.get_tags(run_id, expdb), + database.runs.get_input_data(run_id, expdb), + database.runs.get_output_files(run_id, expdb), + database.runs.get_evaluations(run_id, expdb, evaluation_engine_ids=engine_ids), + database.tasks.get_task_type_name(run.task_id, expdb), + database.tasks.get_task_evaluation_measure(run.task_id, expdb), + database.setups.get(run.setup, expdb), + database.setups.get_parameters(run.setup, expdb), + ), + ) + return RunContext( + uploader_name=uploader_user.full_name if uploader_user else None, + tags=tags, + input_data_rows=input_data_rows, + output_file_rows=output_file_rows, + evaluation_rows=evaluation_rows, + task_type=task_type, + task_evaluation_measure=task_evaluation_measure, + setup=setup, + parameter_rows=parameter_rows, + ) + + +def _build_parameter_settings(parameter_rows: list["Row"]) -> list[ParameterSetting]: + return [ + ParameterSetting( + name=p["name"], + value=p["value"], + component=p["flow_id"], + ) + for p in parameter_rows + ] + + +def _build_input_datasets(rows: list["Row"]) -> list[InputDataset]: + return [InputDataset(did=row.did, name=row.name, url=row.url) for row in rows] + + +def _build_output_files(rows: list["Row"]) -> list[OutputFile]: + """Build output files list. + + Note: the PHP response includes a deprecated `did` field hardcoded to "-1" + for each file. This implementation omits it entirely. + """ + return [OutputFile(file_id=row.file_id, name=row.field) for row in rows] + + +def _build_evaluations(rows: list["Row"]) -> list[EvaluationScore]: + def _normalise_value(v: object) -> object: + if isinstance(v, (int, float)): + return int(v) if float(v).is_integer() else v + if isinstance(v, str): + try: + f = float(v) + return int(f) if f.is_integer() else f + except ValueError: + return None + return None + + return [ + EvaluationScore( + name=row.name, + value=_normalise_value(row.value), + array_data=row.array_data, + repeat=getattr(row, "repeat", None), + fold=getattr(row, "fold", None), + ) + for row in rows + ] + + +@router.get("/{run_id}", response_model_exclude_none=True) +async def get_run( + run_id: int, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], + userdb: Annotated[AsyncConnection, Depends(userdb_connection)], +) -> Run: + """Get full metadata for a run by ID. 
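+
+    Example: `GET /run/24` returns the payload described by `schemas.runs.Run`.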
+ + No authentication or visibility check is performed — all runs are + publicly accessible. + """ + run = await database.runs.get(run_id, expdb) + if run is None: + msg = f"Run {run_id} not found." + raise RunNotFoundError(msg, code=236) + + engine_ids: list[int] = config.load_run_configuration().get("evaluation_engine_ids", [1]) + ctx = await _load_run_context(run, run_id, expdb, userdb, engine_ids) + + flow = await database.flows.get(ctx.setup.implementation_id, expdb) if ctx.setup else None + + parameter_settings = _build_parameter_settings(ctx.parameter_rows) + input_datasets = _build_input_datasets(ctx.input_data_rows) + output_files = _build_output_files(ctx.output_file_rows) + evaluations = _build_evaluations(ctx.evaluation_rows) + + normalised_measure = ctx.task_evaluation_measure or None + error_messages = [run.error_message] if run.error_message else [] + + return Run( + run_id=run_id, + uploader=run.uploader, + uploader_name=ctx.uploader_name, + task_id=run.task_id, + task_type=ctx.task_type, + task_evaluation_measure=normalised_measure, + flow_id=ctx.setup.implementation_id if ctx.setup else None, + flow_name=flow.full_name if flow else None, + setup_id=run.setup, + setup_string=ctx.setup.setup_string if ctx.setup else None, + parameter_setting=parameter_settings, + error_message=error_messages, + tag=ctx.tags, + input_data=input_datasets, + output_data=OutputData(file=output_files, evaluation=evaluations), + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py index 857f4921..15d887ac 100644 --- a/src/schemas/runs.py +++ b/src/schemas/runs.py @@ -1,6 +1,6 @@ """Pydantic schemas for run-related endpoints.""" -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field class TraceIteration(BaseModel): @@ -19,3 +19,90 @@ class RunTrace(BaseModel): run_id: int trace: list[TraceIteration] + + +class ParameterSetting(BaseModel): + """A single hyperparameter value used in a run's setup. + + `component` is the `implementation_id` of the flow that defines this + parameter — useful when a setup spans multiple sub-flows (components). + `value` is None when the parameter was not explicitly set (uses default). + """ + + name: str + value: str | None + component: int # = input.implementation_id (flow_id of the owning component) + + +class InputDataset(BaseModel): + """A dataset used as input for a run. + + Sourced from `input_data` JOIN `dataset`. `name` and `url` are fetched + from the `dataset` table and match the values PHP returns. + """ + + did: int + name: str + url: str # ARFF download URL stored in dataset.url + + +class OutputFile(BaseModel): + """An output file produced by or attached to a run. + + Sourced from the `runfile` table. `name` is the file label + (e.g. "description", "predictions"). + + Note: the legacy PHP response included a `did` field hardcoded to "-1" + for every entry here. It originates from a deprecated idea that run outputs + could create new datasets. It is intentionally omitted in this implementation. + """ + + file_id: int + name: str # label as stored in runfile.field, e.g. "description", "predictions" + + +class EvaluationScore(BaseModel): + """An evaluation metric score for a run. + + Sourced from a JOIN of `evaluation` and `math_function`. + `array_data` holds per-fold/per-class breakdowns when available; + `value` holds the aggregate scalar. + `repeat` and `fold` are present for per-fold metrics. 
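+
+    Illustrative serialized entries (metric name and fold indices taken from the
+    test fixtures; numeric values are placeholders):
+
+        {"name": "area_under_roc_curve", "value": 0.95, "array_data": null}
+        {"name": "area_under_roc_curve", "value": 0.95, "array_data": null, "repeat": 0, "fold": 2}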
+ """ + + name: str + value: float | int | None # whole numbers returned as int to match PHP + array_data: str | None + repeat: int | None = None + fold: int | None = None + + +class OutputData(BaseModel): + """Wrapper for output files and evaluations.""" + + file: list[OutputFile] + evaluation: list[EvaluationScore] + + +class Run(BaseModel): + """Full metadata response for a single OpenML run.""" + + model_config = ConfigDict(populate_by_name=True) + + run_id: int + uploader: int # user ID of the uploader + uploader_name: str | None + task_id: int + task_type: str | None # e.g. "Supervised Classification" + task_evaluation_measure: str | None + flow_id: int | None = None + flow_name: str | None + setup_id: int | None = None + setup_string: str | None # human-readable description of the setup + parameter_setting: list[ParameterSetting] + # Serialized as "error" in JSON to match the PHP response key. + # At the Python level we keep the name error_message for clarity. + error_message: list[str] = Field(serialization_alias="error") + tag: list[str] + input_data: list[InputDataset] + output_data: OutputData diff --git a/tests/database/runs_test.py b/tests/database/runs_test.py new file mode 100644 index 00000000..cabdbd86 --- /dev/null +++ b/tests/database/runs_test.py @@ -0,0 +1,107 @@ +"""Tests for database layer of runs.""" + +from sqlalchemy.ext.asyncio import AsyncConnection + +import database.runs +import database.tasks +import database.users + +_RUN_ID = 24 +_MISSING_RUN_ID = 999_999_999 +_MISSING_USER_ID = 999_999_999 +_RUN_UPLOADER_ID = 1159 +_RUN_TASK_ID = 115 +_RUN_SETUP_ID = 2 +_RUN_DATASET_ID = 20 +_DESCRIPTION_FILE_ID = 182 +_PREDICTIONS_FILE_ID = 183 + + +async def test_db_get_run_exists(expdb_test: AsyncConnection) -> None: + """database.runs.get returns a row for run 24.""" + row = await database.runs.get(_RUN_ID, expdb_test) + assert row is not None + assert row.rid == _RUN_ID + assert row.uploader == _RUN_UPLOADER_ID + assert row.task_id == _RUN_TASK_ID + assert row.setup == _RUN_SETUP_ID + assert row.error_message is None # no error for this run + + +async def test_db_get_run_missing(expdb_test: AsyncConnection) -> None: + """database.runs.get returns None for a non-existent run.""" + row = await database.runs.get(_MISSING_RUN_ID, expdb_test) + assert row is None + + +async def test_db_exist_true(expdb_test: AsyncConnection) -> None: + """database.runs.exist returns True for run 24.""" + assert await database.runs.exist(_RUN_ID, expdb_test) is True + + +async def test_db_exist_false(expdb_test: AsyncConnection) -> None: + """database.runs.exist returns False for a missing run.""" + assert await database.runs.exist(_MISSING_RUN_ID, expdb_test) is False + + +async def test_db_get_tags(expdb_test: AsyncConnection) -> None: + """database.runs.get_tags returns expected tags for run 24.""" + tags = await database.runs.get_tags(_RUN_ID, expdb_test) + assert isinstance(tags, list) + assert "openml-python" in tags + + +async def test_db_get_input_data(expdb_test: AsyncConnection) -> None: + """database.runs.get_input_data returns did=20 (diabetes) for run 24.""" + rows = await database.runs.get_input_data(_RUN_ID, expdb_test) + assert len(rows) >= 1 + dids = [r.did for r in rows] + assert _RUN_DATASET_ID in dids + + +async def test_db_get_output_files(expdb_test: AsyncConnection) -> None: + """database.runs.get_output_files returns description and predictions files.""" + rows = await database.runs.get_output_files(_RUN_ID, expdb_test) + file_map = {r.field: r.file_id for r in rows} 
+ assert file_map.get("description") == _DESCRIPTION_FILE_ID + assert file_map.get("predictions") == _PREDICTIONS_FILE_ID + + +async def test_db_get_evaluations(expdb_test: AsyncConnection) -> None: + """database.runs.get_evaluations returns metrics including area_under_roc_curve.""" + rows = await database.runs.get_evaluations(_RUN_ID, expdb_test, evaluation_engine_ids=[1]) + assert len(rows) > 0 + names = {r.name for r in rows} + assert "area_under_roc_curve" in names + + +async def test_db_get_evaluations_empty_engine_list(expdb_test: AsyncConnection) -> None: + """get_evaluations with no engine IDs returns an empty list (not an error).""" + rows = await database.runs.get_evaluations(_RUN_ID, expdb_test, evaluation_engine_ids=[]) + assert rows == [] + + +async def test_db_get_task_type(expdb_test: AsyncConnection) -> None: + """database.runs.get_task_type returns 'Supervised Classification' for task 115.""" + task_type = await database.tasks.get_task_type_name(_RUN_TASK_ID, expdb_test) + assert task_type == "Supervised Classification" + + +async def test_db_get_task_evaluation_measure_missing(expdb_test: AsyncConnection) -> None: + """get_task_evaluation_measure returns None (not '') when absent.""" + measure = await database.tasks.get_task_evaluation_measure(_RUN_TASK_ID, expdb_test) + assert measure is None + + +async def test_db_get_uploader_name(user_test: AsyncConnection) -> None: + """database.runs.get_uploader_name returns 'Cynthia Glover' for user 1159.""" + user = await database.users.get_user(user_id=_RUN_UPLOADER_ID, connection=user_test) + assert user is not None + assert user.full_name == "Cynthia Glover" + assert user.user_id == _RUN_UPLOADER_ID + + +async def test_db_get_uploader_name_missing(user_test: AsyncConnection) -> None: + """get_uploader_name returns None for a non-existent user.""" + user = await database.users.get_user(user_id=_MISSING_USER_ID, connection=user_test) + assert user is None diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py new file mode 100644 index 00000000..3a1aa682 --- /dev/null +++ b/tests/routers/openml/runs_get_test.py @@ -0,0 +1,309 @@ +"""Tests for GET /run/{id} endpoint""" + +import asyncio +from http import HTTPStatus +from typing import Any, NamedTuple +from unittest.mock import AsyncMock, patch + +import deepdiff +import httpx +import pytest + +from core.conversions import nested_num_to_str, nested_remove_single_element_list +from routers.openml.runs import _build_evaluations + +# ── Fixtures assume run 24 exists in the test DB (confirmed in research) ── +_RUN_ID = 24 +_MISSING_RUN_ID = 999_999_999 + +_RUN_NOT_FOUND_CODE = "236" + +_RUN_UPLOADER_ID = 1159 +_RUN_TASK_ID = 115 +_RUN_FLOW_ID = 19 +_RUN_SETUP_ID = 2 +_RUN_DATASET_ID = 20 +_DESCRIPTION_FILE_ID = 182 +_PREDICTIONS_FILE_ID = 183 + + +# ════════════════════════════════════════════════════════════════════ +# Happy-path API tests (use py_api httpx client) +# ════════════════════════════════════════════════════════════════════ + + +async def test_get_run_status_ok(py_api: httpx.AsyncClient) -> None: + """GET /run/{id} returns 200 for a known run.""" + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + data = response.json() + assert isinstance(data, dict) + assert "run_id" in data + assert data["run_id"] == _RUN_ID + + +async def test_get_run_happy_path(py_api: httpx.AsyncClient) -> None: # noqa: PLR0915 + """Comprehensive check of run 24.""" + response = await py_api.get(f"/run/{_RUN_ID}") + 
assert response.status_code == HTTPStatus.OK + run = response.json() + + # 1. Top-level shape + expected_keys = { + "run_id", + "uploader", + "uploader_name", + "task_id", + "task_type", + "flow_id", + "flow_name", + "setup_id", + "setup_string", + "parameter_setting", + "error", + "tag", + "input_data", + "output_data", + } + assert expected_keys <= run.keys(), f"Missing keys: {expected_keys - run.keys()}" + + # 2. Known core values + assert run["run_id"] == _RUN_ID + assert run["uploader"] == _RUN_UPLOADER_ID + assert run["uploader_name"] == "Cynthia Glover" + assert run["task_id"] == _RUN_TASK_ID + assert run["task_type"] == "Supervised Classification" + assert run["flow_id"] == _RUN_FLOW_ID + assert run["setup_id"] == _RUN_SETUP_ID + assert "Python_3.10.5" in run["setup_string"] + assert "openml-python" in run["tag"] + assert run["error"] == [] + + # 3. Input Data + datasets = run["input_data"] + assert isinstance(datasets, list) + assert len(datasets) > 0 + dataset = datasets[0] + assert "did" in dataset + assert "name" in dataset + assert "url" in dataset + assert dataset["did"] == _RUN_DATASET_ID + assert dataset["name"] == "diabetes" + + # 4. Output Data Shape + assert "file" in run["output_data"] + assert "evaluation" in run["output_data"] + files = run["output_data"]["file"] + assert isinstance(files, list) + assert len(files) > 0 + file_ = files[0] + assert "file_id" in file_ + assert "name" in file_ + assert "did" not in file_ + + evaluations = run["output_data"]["evaluation"] + assert isinstance(evaluations, list) + assert len(evaluations) > 0 + eval_ = evaluations[0] + assert "name" in eval_ + assert "value" in eval_ + + # 5. Known output files & evaluations + file_map = {f["name"]: f["file_id"] for f in files} + assert file_map.get("description") == _DESCRIPTION_FILE_ID + assert file_map.get("predictions") == _PREDICTIONS_FILE_ID + + eval_names = {e["name"] for e in evaluations} + assert "area_under_roc_curve" in eval_names + + for ev in evaluations: + if ev["value"] is not None and isinstance(ev["value"], float): + assert ev["value"] != int(ev["value"]), "Expected whole-number floats to be int" + + # 6. 
Parameter settings + params = run["parameter_setting"] + assert isinstance(params, list) + for p in params: + assert "name" in p + assert "value" in p + assert "component" in p + assert isinstance(p["component"], int) + + +async def test_get_run_non_empty_error(py_api: httpx.AsyncClient) -> None: + """A run with a non-null error_message is serialized as a single-item error list.""" + + # Since the test database does not have a run with an error, we mock the DB fetch + class MockRunRow(NamedTuple): + rid: int + uploader: int + setup: int + task_id: int + error_message: str + + mock_row = MockRunRow( + rid=_RUN_ID, + uploader=_RUN_UPLOADER_ID, + setup=_RUN_SETUP_ID, + task_id=_RUN_TASK_ID, + error_message="Some error from the backend", + ) + + with patch("routers.openml.runs.database.runs.get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_row + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + run = response.json() + assert run["error"] == ["Some error from the backend"] + + +async def test_get_run_not_found(py_api: httpx.AsyncClient) -> None: + """Non-existent run returns 404 with error code 236 (PHP compat).""" + response = await py_api.get(f"/run/{_MISSING_RUN_ID}") + assert response.status_code == HTTPStatus.NOT_FOUND + error = response.json() + # Verify PHP-compat error code + assert str(error.get("code")) == _RUN_NOT_FOUND_CODE + + +async def test_task_evaluation_measure_omitted_when_null(py_api: httpx.AsyncClient) -> None: + """task_evaluation_measure is not present in JSON when no measure is configured.""" + # Run 24 is known to not have a task evaluation measure (verified in db test) + response = await py_api.get(f"/run/{_RUN_ID}") + run = response.json() + assert "task_evaluation_measure" not in run + + +async def test_task_evaluation_measure_present_when_configured( + py_api: httpx.AsyncClient, +) -> None: + """task_evaluation_measure is present and matches DB when a measure is configured.""" + # Since the test database does not have a run with an evaluation measure, we mock the DB fetch + with patch( + "routers.openml.runs.database.tasks.get_task_evaluation_measure", new_callable=AsyncMock + ) as mock_get_measure: + mock_get_measure.return_value = "predictive_accuracy" + response = await py_api.get(f"/run/{_RUN_ID}") + assert response.status_code == HTTPStatus.OK + + run = response.json() + assert "task_evaluation_measure" in run + assert run["task_evaluation_measure"] == "predictive_accuracy" + + +# ════════════════════════════════════════════════════════════════════ +# Migration tests (Python API vs PHP API parity) +# ════════════════════════════════════════════════════════════════════ + +# Regex paths excluded from DeepDiff — only genuinely untestable fields. +_EXCLUDE_PATHS = [ + # [1] PHP hardcodes did="-1" in output_data.file; Python omits it (deprecated). + r"root\['run'\]\['output_data'\]\['file'\]\[\d+\]\['did'\]", + # [2] PHP generates output file URLs from its own server_url config. + # Python does not yet have a file download endpoint, so URLs differ by design. 
+ r"root\['run'\]\['output_data'\]\['file'\]\[\d+\]\['url'\]", +] + + +def _normalize_py_run(py_run: dict[str, Any]) -> dict[str, Any]: + """Normalize a Python run response to match the PHP response format.""" + run = py_run.copy() + run = nested_remove_single_element_list(run) + + if "input_data" in run: + run["input_data"] = {"dataset": run["input_data"]} + + run = nested_num_to_str(run) + + return {"run": run} + + +# Run IDs to test, including a non-existent one to verify error parity. +_RUN_IDS = [*range(24, 35), 999_999_999] + + +@pytest.mark.parametrize("run_id", _RUN_IDS) +async def test_get_run_equal( + run_id: int, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + """Python and PHP run responses are equivalent after normalization.""" + py_response, php_response = await asyncio.gather( + py_api.get(f"/run/{run_id}"), + php_api.get(f"/run/{run_id}"), + ) + + # Error case: run does not exist. + # PHP returns 412 PRECONDITION_FAILED; Python returns 404 NOT_FOUND. + if php_response.status_code != HTTPStatus.OK: + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.status_code == HTTPStatus.NOT_FOUND + php_code = php_response.json().get("error", {}).get("code") + py_code = py_response.json()["code"] + assert py_code == php_code + return + + assert py_response.status_code == HTTPStatus.OK + + py_normalized = _normalize_py_run(py_response.json()) + php_json = php_response.json() + + # PHP provides evaluation entries natively for each fold (with `repeat` and `fold` keys) + # as well as an aggregate entry. Python now supports returning these exact entries. + + # PHP sometimes includes empty `error` property instead of an empty list when no error occurred + # DeepDiff takes care of it automatically because we didn't see error diffs. + + differences = deepdiff.diff.DeepDiff( + py_normalized, + php_json, + ignore_order=True, + ignore_numeric_type_changes=True, + exclude_regex_paths=_EXCLUDE_PATHS, + ) + assert not differences, f"Differences for run {run_id}: {differences}" + + +@pytest.mark.parametrize( + ("input_value", "expected_value", "repeat", "fold"), + [ + (1.0, 1, None, None), + ("2.0", 2, None, None), + ("1.5", 1.5, None, None), + ("not_a_number", None, None, None), + (["list"], None, None, None), + (0.95, 0.95, 0, 2), + ], +) +def test_build_evaluations( + input_value: object, + expected_value: object, + repeat: int | None, + fold: int | None, +) -> None: + """_build_evaluations normalizes values correctly and maps repeat/fold.""" + + class MockRow: + def __init__( + self, + name: str, + value: object, + array_data: str | None = None, + repeat: int | None = None, + fold: int | None = None, + ) -> None: + self.name = name + self.value = value + self.array_data = array_data + self.repeat = repeat + self.fold = fold + + rows = [MockRow("test_metric", input_value, repeat=repeat, fold=fold)] + evals = _build_evaluations(rows) + + assert len(evals) == 1 + assert evals[0].value == expected_value + assert evals[0].repeat == repeat + assert evals[0].fold == fold