Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/database/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text


def get_run(run_id: int, expdb: Connection) -> Row | None:
"""Check if a run exists. Used to distinguish 571 (run not found) from 572 (no trace)."""
return expdb.execute(
text(
"""
SELECT rid
FROM run
WHERE rid = :run_id
""",
),
parameters={"run_id": run_id},
).one_or_none()


def get_trace(run_id: int, expdb: Connection) -> Sequence[Row]:
"""Fetch all trace iterations for a run, ordered as PHP does: repeat, fold, iteration."""
return cast(
"Sequence[Row]",
expdb.execute(
text(
"""
SELECT `repeat`, `fold`, `iteration`, setup_string, evaluation, selected
FROM trace
WHERE run_id = :run_id
ORDER BY `repeat` ASC, `fold` ASC, `iteration` ASC
""",
),
parameters={"run_id": run_id},
).all(),
)
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from routers.openml.evaluations import router as evaluationmeasures_router
from routers.openml.flows import router as flows_router
from routers.openml.qualities import router as qualities_router
from routers.openml.runs import router as runs_router
from routers.openml.setups import router as setup_router
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
Expand Down Expand Up @@ -69,6 +70,7 @@ def create_api() -> FastAPI:
app.include_router(task_router)
app.include_router(flows_router)
app.include_router(study_router)
app.include_router(runs_router)
app.include_router(setup_router)
return app

Expand Down
56 changes: 56 additions & 0 deletions src/routers/openml/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from http import HTTPStatus
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import Connection

import database.runs
from routers.dependencies import expdb_connection
from schemas.runs import RunTrace, RunTraceResponse, TraceIteration

router = APIRouter(prefix="/runs", tags=["runs"])


@router.get("/trace/{run_id}")
def get_run_trace(
run_id: int,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> RunTraceResponse:
"""Get the optimization trace for a run.

Returns all hyperparameter configurations tried during tuning, their
evaluations, and whether each was selected. Mirrors PHP API behavior.
"""
# 571: run does not exist at all
if not database.runs.get_run(run_id, expdb):
raise HTTPException(
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "571", "message": "Run not found."},
)

trace_rows = database.runs.get_trace(run_id, expdb)

# 572: run exists but has no trace data
if not trace_rows:
raise HTTPException(
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "572", "message": "No trace found for run."},
)

return RunTraceResponse(
trace=RunTrace(
# Cast to str: PHP returns run_id and all iteration fields as strings.
run_id=str(run_id),
trace_iteration=[
TraceIteration(
repeat=str(row.repeat),
fold=str(row.fold),
iteration=str(row.iteration),
setup_string=row.setup_string,
evaluation=row.evaluation,
selected=row.selected,
)
for row in trace_rows
],
),
)
22 changes: 22 additions & 0 deletions src/schemas/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Literal

from pydantic import BaseModel


class TraceIteration(BaseModel):
repeat: str
fold: str
iteration: str
setup_string: str
evaluation: str
selected: Literal["true", "false"]


class RunTrace(BaseModel):
run_id: str
trace_iteration: list[TraceIteration]


# Wraps RunTrace in {"trace": {...}} to match PHP API response structure.
class RunTraceResponse(BaseModel):
trace: RunTrace
71 changes: 71 additions & 0 deletions tests/routers/openml/runs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from http import HTTPStatus

import pytest
from starlette.testclient import TestClient


@pytest.mark.parametrize("run_id", [34])
def test_get_run_trace(py_api: TestClient, run_id: int) -> None:
response = py_api.get(f"/runs/trace/{run_id}")
assert response.status_code == HTTPStatus.OK

body = response.json()
assert "trace" in body

trace = body["trace"]
assert trace["run_id"] == str(run_id)
assert "trace_iteration" in trace
assert len(trace["trace_iteration"]) > 0

# Verify structure and types of each iteration — PHP returns all fields as strings
for iteration in trace["trace_iteration"]:
assert "repeat" in iteration
assert "fold" in iteration
assert "iteration" in iteration
assert "setup_string" in iteration
assert "evaluation" in iteration
assert "selected" in iteration
assert isinstance(iteration["repeat"], str)
assert isinstance(iteration["fold"], str)
assert isinstance(iteration["iteration"], str)
assert isinstance(iteration["setup_string"], str)
assert isinstance(iteration["evaluation"], str)
assert iteration["selected"] in ("true", "false")


def test_get_run_trace_ordering(py_api: TestClient) -> None:
"""Trace iterations must be ordered by repeat, fold, iteration ASC — matches PHP."""
response = py_api.get("/runs/trace/34")
assert response.status_code == HTTPStatus.OK

iterations = response.json()["trace"]["trace_iteration"]
keys = [(int(i["repeat"]), int(i["fold"]), int(i["iteration"])) for i in iterations]
assert keys == sorted(keys)


def test_get_run_trace_run_not_found(py_api: TestClient) -> None:
"""Run does not exist at all — expect error 571."""
response = py_api.get("/runs/trace/999999")
assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"]["code"] == "571"


def test_get_run_trace_negative_id(py_api: TestClient) -> None:
"""Negative run_id can never exist — expect error 571."""
response = py_api.get("/runs/trace/-1")
assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"]["code"] == "571"


def test_get_run_trace_invalid_id(py_api: TestClient) -> None:
"""Non-integer run_id — FastAPI should reject with 422 before hitting our handler."""
response = py_api.get("/runs/trace/abc")
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY


def test_get_run_trace_no_trace(py_api: TestClient) -> None:
"""Run exists but has no trace data — expect error 572.
Run 24 exists in the test DB but has no trace rows."""
response = py_api.get("/runs/trace/24")
assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"]["code"] == "572"
Loading