Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ For example, after tagging dataset 21 with the tag `"foo"`:

## Setups

### `GET /setup/{id}`
The endpoint behaves almost identically to the PHP implementation. Note that fields representing integers like `setup_id` and `flow_id` are returned as integers instead of strings to align with typed JSON. Also, if a setup has no parameters, the `parameter` field is omitted entirely from the response.

### `POST /setup/tag` and `POST /setup/untag`
When successful, the "tag" property in the returned response is now always a list, even if only one tag exists for the entity. When removing the last tag, the "tag" property will be an empty list `[]` instead of being omitted from the response.

Expand Down
19 changes: 14 additions & 5 deletions src/core/conversions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any

Expand All @@ -7,9 +8,13 @@ def _str_to_num(string: str) -> int | float | str:
if string.isdigit():
return int(string)
try:
return float(string)
f = float(string)
if math.isnan(f) or math.isinf(f):
return string
except ValueError:
return string
else:
return f


def nested_str_to_num(obj: Any) -> Any:
Expand Down Expand Up @@ -42,17 +47,21 @@ def nested_num_to_str(obj: Any) -> Any:
return obj


def nested_remove_values(obj: Any, *, values: list[Any]) -> Any:
    """Recursively drop every entry that equals one of `values` from nested containers.

    Args:
        obj: Arbitrarily nested structure of mappings, iterables and scalars.
        values: Values to remove. Membership is tested with `in`, i.e. by
            identity or equality, so `values=[0]` would also remove `False`.

    Returns:
        Strings and non-iterable objects are returned unchanged. Mappings are
        returned as a new `dict` and any other iterable as a new `list`, each
        without the entries whose *cleaned* value is in `values`. The top-level
        object itself is never removed.
    """
    if isinstance(obj, str):
        # Strings are iterable, but must be treated as scalar values.
        return obj
    if isinstance(obj, Mapping):
        # Clean each value exactly once: recursing separately for the filter and
        # for the emitted value doubles the work at every nesting level and
        # exhausts one-shot iterators before their content is emitted.
        cleaned_items = (
            (key, nested_remove_values(val, values=values)) for key, val in obj.items()
        )
        return {key: cleaned for key, cleaned in cleaned_items if cleaned not in values}
    if isinstance(obj, Iterable):
        cleaned_elements = (nested_remove_values(val, values=values) for val in obj)
        return [cleaned for cleaned in cleaned_elements if cleaned not in values]
    return obj


Expand Down
29 changes: 28 additions & 1 deletion src/database/setups.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""All database operations that directly operate on setups."""

from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection


Expand All @@ -20,6 +20,33 @@ async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
return row.first()


async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[RowMapping]:
    """Fetch the parameter settings stored for the setup with `setup_id`.

    Every returned mapping describes one `input_setting` row joined with its
    `input` and the `implementation` (flow) that input belongs to. The rows are
    ordered by flow id and then by input id. The list is empty when the query
    matches no rows.
    """
    query = text(
        """
        SELECT
        t_input.id as id,
        t_input.implementation_id as flow_id,
        t_impl.name AS flow_name,
        CONCAT(t_impl.fullName, '_', t_input.name) AS full_name,
        t_input.name AS parameter_name,
        t_input.name AS name,
        t_input.dataType AS data_type,
        t_input.defaultValue AS default_value,
        t_setting.value AS value
        FROM input_setting t_setting
        JOIN input t_input ON t_setting.input_id = t_input.id
        JOIN implementation t_impl ON t_input.implementation_id = t_impl.id
        WHERE t_setting.setup = :setup_id
        ORDER BY t_impl.id, t_input.id
        """,
    )
    result = await connection.execute(query, parameters={"setup_id": setup_id})
    # `.mappings()` exposes each row by column alias instead of by position.
    parameter_rows = result.mappings().all()
    return list(parameter_rows)


async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
"""Get all tags for setup with `setup_id` from the database."""
rows = await connection.execute(
Expand Down
25 changes: 24 additions & 1 deletion src/routers/openml/setups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Annotated

from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, Path
from sqlalchemy.ext.asyncio import AsyncConnection

import database.setups
Expand All @@ -15,10 +15,33 @@
from database.users import User, UserGroup
from routers.dependencies import expdb_connection, fetch_user_or_raise
from routers.types import SystemString64
from schemas.setups import SetupParameters, SetupResponse

router = APIRouter(prefix="/setup", tags=["setup"])


@router.get(path="/{setup_id}", response_model_exclude_none=True)
async def get_setup(
    setup_id: Annotated[int, Path()],
    expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> SetupResponse:
    """Return the setup identified by `setup_id` together with its parameters.

    Raises `SetupNotFoundError` (code 281) when no setup with this id exists.
    """
    found = await database.setups.get(setup_id, expdb_db)
    if not found:
        msg = f"Setup {setup_id} not found."
        raise SetupNotFoundError(msg, code=281)

    parameter_rows = await database.setups.get_parameters(setup_id, expdb_db)
    return SetupResponse(
        setup_parameters=SetupParameters(
            setup_id=setup_id,
            flow_id=found.implementation_id,
            # An empty result is passed as None; together with
            # `response_model_exclude_none=True` the field is then left out.
            parameter=parameter_rows or None,
        ),
    )


@router.post(path="/tag")
async def tag_setup(
setup_id: Annotated[int, Body()],
Expand Down
37 changes: 37 additions & 0 deletions src/schemas/setups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Pydantic schemas for the setup API endpoints."""

from pydantic import BaseModel, ConfigDict


class SetupParameter(BaseModel):
    """One parameter setting of a setup.

    `data_type`, `default_value` and `value` are optional and default to `None`;
    all other fields are required.
    """

    # Permit validation from objects that expose the fields as attributes.
    model_config = ConfigDict(from_attributes=True)

    id: int
    flow_id: int
    flow_name: str
    full_name: str
    parameter_name: str
    name: str
    data_type: str | None = None
    default_value: str | None = None
    value: str | None = None


class SetupParameters(BaseModel):
    """A setup's identifiers together with its (optional) list of parameters.

    `parameter` defaults to `None` when no parameter list is provided.
    """

    # Permit validation from objects that expose the fields as attributes.
    model_config = ConfigDict(from_attributes=True)

    setup_id: int
    flow_id: int
    parameter: list[SetupParameter] | None = None


class SetupResponse(BaseModel):
    """Top-level body returned by the GET /setup/{id} endpoint.

    The payload is wrapped in a single `setup_parameters` object.
    """

    # Permit validation from objects that expose the fields as attributes.
    model_config = ConfigDict(from_attributes=True)

    setup_parameters: SetupParameters
53 changes: 52 additions & 1 deletion tests/routers/openml/migration/setups_migration_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import contextlib
import re
from collections.abc import AsyncGenerator, Callable, Iterable
Expand All @@ -9,6 +10,7 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection

from core.conversions import nested_remove_values, nested_str_to_num
from tests.conftest import temporary_records
from tests.users import OWNER_USER, ApiKey

Expand Down Expand Up @@ -266,6 +268,55 @@ async def test_setup_tag_response_is_identical_tag_already_exists(

assert original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
assert new.status_code == HTTPStatus.CONFLICT
assert original.json()["error"]["code"] == new.json()["code"]
assert original.json()["error"]["message"] == "Entity already tagged by this tag."
assert new.json()["detail"] == f"Setup {setup_id} already has tag {tag!r}."


async def test_get_setup_response_is_identical_setup_doesnt_exist(
    py_api: httpx.AsyncClient,
    php_api: httpx.AsyncClient,
) -> None:
    """Both APIs reject an unknown setup id with the same error code.

    Status codes and messages differ between the implementations, so they are
    asserted separately instead of being compared to each other.
    """
    setup_id = 999999
    path = f"/setup/{setup_id}"

    php_response, py_response = await asyncio.gather(
        php_api.get(path),
        py_api.get(path),
    )

    assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
    assert py_response.status_code == HTTPStatus.NOT_FOUND

    php_error = php_response.json()["error"]
    assert php_error["message"] == "Unknown setup"
    py_body = py_response.json()
    assert php_error["code"] == py_body["code"]
    assert py_body["detail"] == f"Setup {setup_id} not found."


@pytest.mark.parametrize("setup_id", range(1, 125))
async def test_get_setup_response_is_identical(
    setup_id: int,
    py_api: httpx.AsyncClient,
    php_api: httpx.AsyncClient,
) -> None:
    """The Python response for a setup matches the PHP response after normalization."""
    path = f"/setup/{setup_id}"
    php_response, py_response = await asyncio.gather(
        php_api.get(path),
        py_api.get(path),
    )

    # A setup unknown to PHP must be unknown to the Python API as well.
    if php_response.status_code == HTTPStatus.PRECONDITION_FAILED:
        assert py_response.status_code == HTTPStatus.NOT_FOUND
        return

    assert php_response.status_code == HTTPStatus.OK
    assert py_response.status_code == HTTPStatus.OK

    # PHP returns integer fields as strings, so numeric strings are converted to
    # numbers on both sides before comparing.
    # PHP also returns `[]` instead of null for empty optional fields, which Python
    # omits, so empty lists and None values are stripped from both payloads.
    expected, actual = (
        nested_remove_values(nested_str_to_num(response.json()), values=[[], None])
        for response in (php_response, py_response)
    )

    assert expected == actual
4 changes: 2 additions & 2 deletions tests/routers/openml/migration/studies_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import deepdiff
import httpx

from core.conversions import nested_num_to_str, nested_remove_nones
from core.conversions import nested_num_to_str, nested_remove_values


async def test_get_study_equal(py_api: httpx.AsyncClient, php_api: httpx.AsyncClient) -> None:
Expand All @@ -17,7 +17,7 @@ async def test_get_study_equal(py_api: httpx.AsyncClient, php_api: httpx.AsyncCl
# New implementation is typed
new_json = nested_num_to_str(new_json)
# New implementation has same fields even if empty
new_json = nested_remove_nones(new_json)
new_json = nested_remove_values(new_json, values=[None])
new_json["tasks"] = {"task_id": new_json.pop("task_ids")}
new_json["data"] = {"data_id": new_json.pop("data_ids")}
if runs := new_json.pop("run_ids", None):
Expand Down
4 changes: 2 additions & 2 deletions tests/routers/openml/migration/tasks_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from core.conversions import (
nested_num_to_str,
nested_remove_nones,
nested_remove_single_element_list,
nested_remove_values,
)


Expand All @@ -32,7 +32,7 @@ async def test_get_task_equal(
new_json["task_id"] = new_json.pop("id")
new_json["task_name"] = new_json.pop("name")
# PHP is not typed *and* automatically removes None values
new_json = nested_remove_nones(new_json)
new_json = nested_remove_values(new_json, values=[None])
new_json = nested_num_to_str(new_json)
# It also removes "value" entries for parameters if the list is empty,
# it does not remove *all* empty lists, e.g., for cost_matrix input they are kept
Expand Down
14 changes: 14 additions & 0 deletions tests/routers/openml/setups_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,17 @@ async def test_setup_tag_success(py_api: httpx.AsyncClient, expdb_test: AsyncCon
text("SELECT * FROM setup_tag WHERE id = 1 AND tag = 'my_new_success_tag'")
)
assert len(rows.all()) == 1


async def test_get_setup_unknown(py_api: httpx.AsyncClient) -> None:
    """Requesting a setup id that does not exist yields 404 with a descriptive detail."""
    response = await py_api.get("/setup/999999")
    assert response.status_code == HTTPStatus.NOT_FOUND
    # Escape the final dot and match the whole string, so a message with a
    # different last character or trailing text does not pass.
    assert re.fullmatch(r"Setup \d+ not found\.", response.json()["detail"])


async def test_get_setup_success(py_api: httpx.AsyncClient) -> None:
    """Fetching setup 1 succeeds and the body echoes the id and lists parameters."""
    response = await py_api.get("/setup/1")
    assert response.status_code == HTTPStatus.OK

    setup_parameters = response.json()["setup_parameters"]
    assert setup_parameters["setup_id"] == 1
    assert "parameter" in setup_parameters
Loading