diff --git a/src/core/errors.py b/src/core/errors.py
index 3f53364..a619191 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -385,3 +385,26 @@ class InternalError(ProblemDetailError):
     uri = "https://openml.org/problems/internal-error"
     title = "Internal Server Error"
     _default_status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+# =============================================================================
+# Run Errors
+# =============================================================================
+
+
+class RunNotFoundError(ProblemDetailError):
+    """Raised when a run cannot be found."""
+
+    uri = "https://openml.org/problems/run-not-found"
+    title = "Run Not Found"
+    _default_status_code = HTTPStatus.PRECONDITION_FAILED
+    _default_code = 571
+
+
+class RunTraceNotFoundError(ProblemDetailError):
+    """Raised when trace data for a run cannot be found."""
+
+    uri = "https://openml.org/problems/run-trace-not-found"
+    title = "Run Trace Not Found"
+    _default_status_code = HTTPStatus.PRECONDITION_FAILED
+    _default_code = 572
diff --git a/src/database/runs.py b/src/database/runs.py
new file mode 100644
index 0000000..d0fb1f2
--- /dev/null
+++ b/src/database/runs.py
@@ -0,0 +1,49 @@
+"""Database queries for run-related data."""
+
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import Row, text
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+
+async def get(id_: int, expdb: AsyncConnection) -> Row | None:
+    """Check if a run exists by ID.
+
+    Returns a one-column row when a run whose `rid` equals `id_` exists and
+    `None` otherwise, so the result is meant to be used as a truth value.
+    """
+    row = await expdb.execute(
+        text(
+            """
+            SELECT 1
+            FROM `run`
+            WHERE `rid` = :run_id
+            """,
+        ),
+        parameters={"run_id": id_},
+    )
+    return row.one_or_none()
+
+
+async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+    """Get trace rows for a run from the trace table.
+
+    Rows are sorted by repeat, fold and iteration because SQL guarantees no
+    row order without ORDER BY. The result is empty if the run has no trace.
+    """
+    rows = await expdb.execute(
+        text(
+            """
+            SELECT `repeat`, `fold`, `iteration`, `setup_string`, `evaluation`, `selected`
+            FROM `trace`
+            WHERE `run_id` = :run_id
+            ORDER BY `repeat`, `fold`, `iteration`
+            """,
+        ),
+        parameters={"run_id": run_id},
+    )
+    return cast(
+        "Sequence[Row]",
+        rows.all(),
+    )
diff --git a/src/main.py b/src/main.py
index 76a52ad..19a6c84 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,6 +15,7 @@
 from routers.openml.evaluations import router as evaluationmeasures_router
 from routers.openml.flows import router as flows_router
 from routers.openml.qualities import router as qualities_router
+from routers.openml.runs import router as runs_router
 from routers.openml.setups import router as setup_router
 from routers.openml.study import router as study_router
 from routers.openml.tasks import router as task_router
@@ -70,6 +71,7 @@ def create_api() -> FastAPI:
     app.include_router(flows_router)
     app.include_router(study_router)
     app.include_router(setup_router)
+    app.include_router(runs_router)
     return app
 
 
diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py
new file mode 100644
index 0000000..03f51e4
--- /dev/null
+++ b/src/routers/openml/runs.py
@@ -0,0 +1,48 @@
+"""Endpoints for run-related data."""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+import database.runs
+from core.errors import RunNotFoundError, RunTraceNotFoundError
+from routers.dependencies import expdb_connection
+from schemas.runs import RunTrace, TraceIteration
+
+router = APIRouter(prefix="/runs", tags=["runs"])
+
+
+@router.get("/trace/{run_id}")
+async def get_run_trace(
+    run_id: int,
+    expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
+) -> RunTrace:
+    """Get trace data for a run by run ID.
+
+    Raises RunNotFoundError if no run has this ID, and RunTraceNotFoundError
+    if the run exists but has no trace rows.
+    """
+    if not await database.runs.get(run_id, expdb):
+        msg = f"Run {run_id} not found."
+        raise RunNotFoundError(msg)
+
+    trace_rows = await database.runs.get_trace(run_id, expdb)
+    if not trace_rows:
+        msg = f"No trace found for run {run_id}."
+        raise RunTraceNotFoundError(msg)
+
+    return RunTrace(
+        run_id=run_id,
+        trace=[
+            TraceIteration(
+                repeat=row.repeat,
+                fold=row.fold,
+                iteration=row.iteration,
+                setup_string=row.setup_string,
+                evaluation=row.evaluation,
+                selected=row.selected,
+            )
+            for row in trace_rows
+        ],
+    )
diff --git a/src/schemas/runs.py b/src/schemas/runs.py
new file mode 100644
index 0000000..857f492
--- /dev/null
+++ b/src/schemas/runs.py
@@ -0,0 +1,21 @@
+"""Pydantic schemas for run-related endpoints."""
+
+from pydantic import BaseModel
+
+
+class TraceIteration(BaseModel):
+    """A single trace iteration for a run."""
+
+    repeat: int
+    fold: int
+    iteration: int
+    setup_string: str | None
+    evaluation: float | None
+    selected: str
+
+
+class RunTrace(BaseModel):
+    """Trace data for a run."""
+
+    run_id: int
+    trace: list[TraceIteration]
diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py
new file mode 100644
index 0000000..f82e6a0
--- /dev/null
+++ b/tests/routers/openml/runs_test.py
@@ -0,0 +1,49 @@
+"""Tests for the GET /runs/trace/{run_id} endpoint."""
+
+from http import HTTPStatus
+
+import httpx
+import pytest
+
+from core.errors import RunNotFoundError, RunTraceNotFoundError
+
+
+@pytest.mark.parametrize("run_id", [34])
+async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that trace data is returned for a run that has trace entries."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.OK
+    body = response.json()
+    assert body["run_id"] == run_id
+    assert isinstance(body["trace"], list)
+    assert len(body["trace"]) > 0
+    first = body["trace"][0]
+    assert isinstance(first["repeat"], int)
+    assert isinstance(first["fold"], int)
+    assert isinstance(first["iteration"], int)
+    assert first["selected"] in ("true", "false")
+    assert first["evaluation"] is None or isinstance(first["evaluation"], float)
+
+
+@pytest.mark.parametrize("run_id", [24])
+async def test_get_run_trace_no_trace(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that 412 is returned for a run that exists but has no trace."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
+    body = response.json()
+    assert body["code"] == "572"  # RunTraceNotFoundError code
+    assert body["type"] == RunTraceNotFoundError.uri
+    assert body["title"] == RunTraceNotFoundError.title
+    assert body["status"] == HTTPStatus.PRECONDITION_FAILED
+
+
+@pytest.mark.parametrize("run_id", [999999])
+async def test_get_run_trace_run_not_found(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that 412 is returned when the run does not exist."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
+    body = response.json()
+    assert body["code"] == "571"  # RunNotFoundError code
+    assert body["type"] == RunNotFoundError.uri
+    assert body["title"] == RunNotFoundError.title
+    assert body["status"] == HTTPStatus.PRECONDITION_FAILED