Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,26 @@ class InternalError(ProblemDetailError):
uri = "https://openml.org/problems/internal-error"
title = "Internal Server Error"
_default_status_code = HTTPStatus.INTERNAL_SERVER_ERROR


# =============================================================================
# Run Errors
# =============================================================================


class RunNotFoundError(ProblemDetailError):
"""Raised when a run cannot be found."""

uri = "https://openml.org/problems/run-not-found"
title = "Run Not Found"
_default_status_code = HTTPStatus.PRECONDITION_FAILED
_default_code = 571


class RunTraceNotFoundError(ProblemDetailError):
"""Raised when trace data for a run cannot be found."""

uri = "https://openml.org/problems/run-trace-not-found"
title = "Run Trace Not Found"
_default_status_code = HTTPStatus.PRECONDITION_FAILED
_default_code = 572
40 changes: 40 additions & 0 deletions src/database/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Database queries for run-related data."""

from collections.abc import Sequence
from typing import cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
"""Check if a run exists by ID."""
row = await expdb.execute(
text(
"""
SELECT 1
FROM `run`
WHERE `rid` = :run_id
""",
),
parameters={"run_id": id_},
)
return row.one_or_none()


async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
"""
SELECT `repeat`, `fold`, `iteration`, `setup_string`, `evaluation`, `selected`
FROM `trace`
WHERE `run_id` = :run_id
""",
),
parameters={"run_id": run_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from routers.openml.evaluations import router as evaluationmeasures_router
from routers.openml.flows import router as flows_router
from routers.openml.qualities import router as qualities_router
from routers.openml.runs import router as runs_router
from routers.openml.setups import router as setup_router
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
Expand Down Expand Up @@ -70,6 +71,7 @@ def create_api() -> FastAPI:
app.include_router(flows_router)
app.include_router(study_router)
app.include_router(setup_router)
app.include_router(runs_router)
return app


Expand Down
44 changes: 44 additions & 0 deletions src/routers/openml/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Endpoints for run-related data."""

from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncConnection

import database.runs
from core.errors import RunNotFoundError, RunTraceNotFoundError
from routers.dependencies import expdb_connection
from schemas.runs import RunTrace, TraceIteration

router = APIRouter(prefix="/runs", tags=["runs"])


@router.get("/trace/{run_id}")
async def get_run_trace(
run_id: int,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> RunTrace:
"""Get trace data for a run by run ID."""
if not await database.runs.get(run_id, expdb):
msg = f"Run {run_id} not found."
raise RunNotFoundError(msg)

trace_rows = await database.runs.get_trace(run_id, expdb)
if not trace_rows:
msg = f"No trace found for run {run_id}."
raise RunTraceNotFoundError(msg)

return RunTrace(
run_id=run_id,
trace=[
TraceIteration(
repeat=row.repeat,
fold=row.fold,
iteration=row.iteration,
setup_string=row.setup_string,
evaluation=row.evaluation,
selected=row.selected,
)
for row in trace_rows
],
)
21 changes: 21 additions & 0 deletions src/schemas/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Pydantic schemas for run-related endpoints."""

from pydantic import BaseModel


class TraceIteration(BaseModel):
"""A single trace iteration for a run."""

repeat: int
fold: int
iteration: int
setup_string: str | None
evaluation: float | None
selected: str


class RunTrace(BaseModel):
"""Trace data for a run."""

run_id: int
trace: list[TraceIteration]
49 changes: 49 additions & 0 deletions tests/routers/openml/runs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Tests for the GET /runs/trace/{run_id} endpoint."""

from http import HTTPStatus

import httpx
import pytest

from core.errors import RunNotFoundError, RunTraceNotFoundError


@pytest.mark.parametrize("run_id", [34])
async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None:
"""Test that trace data is returned for a run that has trace entries."""
response = await py_api.get(f"/runs/trace/{run_id}")
assert response.status_code == HTTPStatus.OK
body = response.json()
assert body["run_id"] == run_id
assert isinstance(body["trace"], list)
assert len(body["trace"]) > 0
first = body["trace"][0]
assert isinstance(first["repeat"], int)
assert isinstance(first["fold"], int)
assert isinstance(first["iteration"], int)
assert first["selected"] in ("true", "false")
assert first["evaluation"] is None or isinstance(first["evaluation"], float)


@pytest.mark.parametrize("run_id", [24])
async def test_get_run_trace_no_trace(run_id: int, py_api: httpx.AsyncClient) -> None:
"""Test that 412 is returned for a run that exists but has no trace."""
response = await py_api.get(f"/runs/trace/{run_id}")
assert response.status_code == HTTPStatus.PRECONDITION_FAILED
body = response.json()
assert body["code"] == "572" # RunTraceNotFoundError code
assert body["type"] == RunTraceNotFoundError.uri
assert body["title"] == RunTraceNotFoundError.title
assert body["status"] == HTTPStatus.PRECONDITION_FAILED


@pytest.mark.parametrize("run_id", [999999])
async def test_get_run_trace_run_not_found(run_id: int, py_api: httpx.AsyncClient) -> None:
"""Test that 412 is returned when the run does not exist."""
response = await py_api.get(f"/runs/trace/{run_id}")
assert response.status_code == HTTPStatus.PRECONDITION_FAILED
body = response.json()
assert body["code"] == "571" # RunNotFoundError code
assert body["type"] == RunNotFoundError.uri
assert body["title"] == RunNotFoundError.title
assert body["status"] == HTTPStatus.PRECONDITION_FAILED
Loading