From dd83feaac771af08a9d5c2cd4440d67b088b5ae3 Mon Sep 17 00:00:00 2001
From: saathviksheerla
Date: Sat, 14 Mar 2026 10:38:40 +0530
Subject: [PATCH 1/2] Add GET /runs/trace/{run_id} endpoint

---
 src/core/errors.py                | 23 ++++++++++++++++
 src/database/runs.py              | 40 ++++++++++++++++++++++++++++
 src/main.py                       |  2 ++
 src/routers/openml/runs.py        | 44 +++++++++++++++++++++++++++++++
 src/schemas/runs.py               | 21 +++++++++++++++
 tests/routers/openml/runs_test.py | 42 +++++++++++++++++++++++++++++
 6 files changed, 172 insertions(+)
 create mode 100644 src/database/runs.py
 create mode 100644 src/routers/openml/runs.py
 create mode 100644 src/schemas/runs.py
 create mode 100644 tests/routers/openml/runs_test.py

diff --git a/src/core/errors.py b/src/core/errors.py
index 3f53364..a619191 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -385,3 +385,26 @@ class InternalError(ProblemDetailError):
     uri = "https://openml.org/problems/internal-error"
     title = "Internal Server Error"
     _default_status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+# =============================================================================
+# Run Errors
+# =============================================================================
+
+
+class RunNotFoundError(ProblemDetailError):
+    """Raised when a run cannot be found."""
+
+    uri = "https://openml.org/problems/run-not-found"
+    title = "Run Not Found"
+    _default_status_code = HTTPStatus.PRECONDITION_FAILED
+    _default_code = 571
+
+
+class RunTraceNotFoundError(ProblemDetailError):
+    """Raised when trace data for a run cannot be found."""
+
+    uri = "https://openml.org/problems/run-trace-not-found"
+    title = "Run Trace Not Found"
+    _default_status_code = HTTPStatus.PRECONDITION_FAILED
+    _default_code = 572
diff --git a/src/database/runs.py b/src/database/runs.py
new file mode 100644
index 0000000..1c46881
--- /dev/null
+++ b/src/database/runs.py
@@ -0,0 +1,40 @@
+"""Database queries for run-related data."""
+
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import Row, text
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+
+async def get(id_: int, expdb: AsyncConnection) -> Row | None:
+    """Get a run by ID from the run table."""
+    row = await expdb.execute(
+        text(
+            """
+            SELECT `rid`
+            FROM `run`
+            WHERE `rid` = :run_id
+            """,
+        ),
+        parameters={"run_id": id_},
+    )
+    return row.one_or_none()
+
+
+async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+    """Get trace rows for a run from the trace table."""
+    rows = await expdb.execute(
+        text(
+            """
+            SELECT `repeat`, `fold`, `iteration`, `setup_string`, `evaluation`, `selected`
+            FROM `trace`
+            WHERE `run_id` = :run_id
+            """,
+        ),
+        parameters={"run_id": run_id},
+    )
+    return cast(
+        "Sequence[Row]",
+        rows.all(),
+    )
diff --git a/src/main.py b/src/main.py
index 76a52ad..19a6c84 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,6 +15,7 @@
 from routers.openml.evaluations import router as evaluationmeasures_router
 from routers.openml.flows import router as flows_router
 from routers.openml.qualities import router as qualities_router
+from routers.openml.runs import router as runs_router
 from routers.openml.setups import router as setup_router
 from routers.openml.study import router as study_router
 from routers.openml.tasks import router as task_router
@@ -70,6 +71,7 @@ def create_api() -> FastAPI:
     app.include_router(flows_router)
     app.include_router(study_router)
     app.include_router(setup_router)
+    app.include_router(runs_router)
     return app
 
 
diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py
new file mode 100644
index 0000000..d67ac60
--- /dev/null
+++ b/src/routers/openml/runs.py
@@ -0,0 +1,44 @@
+"""Endpoints for run-related data."""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+import database.runs
+from core.errors import RunNotFoundError, RunTraceNotFoundError
+from routers.dependencies import expdb_connection
+from schemas.runs import RunTrace, TraceIteration
+
+router = APIRouter(prefix="/runs", tags=["runs"])
+
+
+@router.get("/trace/{run_id}")
+async def get_run_trace(
+    run_id: int,
+    expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
+) -> RunTrace:
+    """Get trace data for a run by run ID."""
+    if not await database.runs.get(run_id, expdb):
+        msg = f"Run {run_id} not found."
+        raise RunNotFoundError(msg)
+
+    trace_rows = await database.runs.get_trace(run_id, expdb)
+    if not trace_rows:
+        msg = f"No trace found for run {run_id}."
+        raise RunTraceNotFoundError(msg)
+
+    return RunTrace(
+        run_id=run_id,
+        trace=[
+            TraceIteration(
+                repeat=row.repeat,
+                fold=row.fold,
+                iteration=row.iteration,
+                setup_string=row.setup_string,
+                evaluation=row.evaluation,
+                selected=row.selected == "true",
+            )
+            for row in trace_rows
+        ],
+    )
diff --git a/src/schemas/runs.py b/src/schemas/runs.py
new file mode 100644
index 0000000..db5f25a
--- /dev/null
+++ b/src/schemas/runs.py
@@ -0,0 +1,21 @@
+"""Pydantic schemas for run-related endpoints."""
+
+from pydantic import BaseModel, Field
+
+
+class TraceIteration(BaseModel):
+    """A single trace iteration for a run."""
+
+    repeat: int
+    fold: int
+    iteration: int
+    setup_string: str | None
+    evaluation: float | None
+    selected: bool
+
+
+class RunTrace(BaseModel):
+    """Trace data for a run."""
+
+    run_id: int = Field(serialization_alias="run_id")
+    trace: list[TraceIteration]
diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py
new file mode 100644
index 0000000..dd5dca2
--- /dev/null
+++ b/tests/routers/openml/runs_test.py
@@ -0,0 +1,42 @@
+"""Tests for the GET /runs/trace/{run_id} endpoint."""
+
+from http import HTTPStatus
+
+import httpx
+import pytest
+
+
+@pytest.mark.parametrize("run_id", [34])
+async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that trace data is returned for a run that has trace entries."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.OK
+    body = response.json()
+    assert body["run_id"] == run_id
+    assert isinstance(body["trace"], list)
+    assert len(body["trace"]) > 0
+    first = body["trace"][0]
+    assert "repeat" in first
+    assert "fold" in first
+    assert "iteration" in first
+    assert "setup_string" in first
+    assert "evaluation" in first
+    assert "selected" in first
+
+
+@pytest.mark.parametrize("run_id", [24])
+async def test_get_run_trace_no_trace(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that 412 is returned for a run that exists but has no trace."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
+    body = response.json()
+    assert body["code"] == "572"
+
+
+@pytest.mark.parametrize("run_id", [999999])
+async def test_get_run_trace_run_not_found(run_id: int, py_api: httpx.AsyncClient) -> None:
+    """Test that 412 is returned when the run does not exist."""
+    response = await py_api.get(f"/runs/trace/{run_id}")
+    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
+    body = response.json()
+    assert body["code"] == "571"

From 966ba5217ba10f62fbe2cd5cb4cdb960ce23c8a2 Mon Sep 17 00:00:00 2001
From: saathviksheerla
Date: Sat, 14 Mar 2026 11:44:01 +0530
Subject: [PATCH 2/2] Address bot review feedback

---
 src/database/runs.py              |  4 ++--
 src/routers/openml/runs.py        |  2 +-
 src/schemas/runs.py               |  6 +++---
 tests/routers/openml/runs_test.py | 23 +++++++++++++++--------
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/database/runs.py b/src/database/runs.py
index 1c46881..d0fb1f2 100644
--- a/src/database/runs.py
+++ b/src/database/runs.py
@@ -8,11 +8,11 @@
 
 
 async def get(id_: int, expdb: AsyncConnection) -> Row | None:
-    """Get a run by ID from the run table."""
+    """Check if a run exists by ID."""
     row = await expdb.execute(
         text(
             """
-            SELECT `rid`
+            SELECT 1
             FROM `run`
             WHERE `rid` = :run_id
             """,
diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py
index d67ac60..03f51e4 100644
--- a/src/routers/openml/runs.py
+++ b/src/routers/openml/runs.py
@@ -37,7 +37,7 @@ async def get_run_trace(
                 iteration=row.iteration,
                 setup_string=row.setup_string,
                 evaluation=row.evaluation,
-                selected=row.selected == "true",
+                selected=row.selected,
             )
             for row in trace_rows
         ],
diff --git a/src/schemas/runs.py b/src/schemas/runs.py
index db5f25a..857f492 100644
--- a/src/schemas/runs.py
+++ b/src/schemas/runs.py
@@ -1,6 +1,6 @@
 """Pydantic schemas for run-related endpoints."""
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 
 class TraceIteration(BaseModel):
@@ -11,11 +11,11 @@ class TraceIteration(BaseModel):
     iteration: int
     setup_string: str | None
     evaluation: float | None
-    selected: bool
+    selected: str
 
 
 class RunTrace(BaseModel):
     """Trace data for a run."""
 
-    run_id: int = Field(serialization_alias="run_id")
+    run_id: int
     trace: list[TraceIteration]
diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py
index dd5dca2..f82e6a0 100644
--- a/tests/routers/openml/runs_test.py
+++ b/tests/routers/openml/runs_test.py
@@ -5,6 +5,8 @@
 import httpx
 import pytest
 
+from core.errors import RunNotFoundError, RunTraceNotFoundError
+
 
 @pytest.mark.parametrize("run_id", [34])
 async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None:
@@ -16,12 +18,11 @@ async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) ->
     assert isinstance(body["trace"], list)
     assert len(body["trace"]) > 0
     first = body["trace"][0]
-    assert "repeat" in first
-    assert "fold" in first
-    assert "iteration" in first
-    assert "setup_string" in first
-    assert "evaluation" in first
-    assert "selected" in first
+    assert isinstance(first["repeat"], int)
+    assert isinstance(first["fold"], int)
+    assert isinstance(first["iteration"], int)
+    assert first["selected"] in ("true", "false")
+    assert first["evaluation"] is None or isinstance(first["evaluation"], float)
 
 
 @pytest.mark.parametrize("run_id", [24])
@@ -30,7 +31,10 @@ async def test_get_run_trace_no_trace(run_id: int, py_api: httpx.AsyncClient) ->
     response = await py_api.get(f"/runs/trace/{run_id}")
     assert response.status_code == HTTPStatus.PRECONDITION_FAILED
     body = response.json()
-    assert body["code"] == "572"
+    assert body["code"] == "572"  # RunTraceNotFoundError code
+    assert body["type"] == RunTraceNotFoundError.uri
+    assert body["title"] == RunTraceNotFoundError.title
+    assert body["status"] == HTTPStatus.PRECONDITION_FAILED
 
 
 @pytest.mark.parametrize("run_id", [999999])
@@ -39,4 +43,7 @@ async def test_get_run_trace_run_not_found(run_id: int, py_api: httpx.AsyncClien
     response = await py_api.get(f"/runs/trace/{run_id}")
     assert response.status_code == HTTPStatus.PRECONDITION_FAILED
     body = response.json()
-    assert body["code"] == "571"
+    assert body["code"] == "571"  # RunNotFoundError code
+    assert body["type"] == RunNotFoundError.uri
+    assert body["title"] == RunNotFoundError.title
+    assert body["status"] == HTTPStatus.PRECONDITION_FAILED