def nested_remove_values(obj: Any, *, values: list[Any]) -> Any:
    """Return a copy of `obj` with every entry equal to one of `values` removed.

    Mappings are rebuilt as dicts and any other non-string iterable is rebuilt
    as a list, recursively. An entry is dropped when its *cleaned* form compares
    equal (``in``) to an item of `values`; a container that only becomes empty
    after cleaning is therefore dropped as well when `values` lists `[]`.
    Strings and non-iterable objects are returned unchanged.

    Args:
        obj: Arbitrarily nested structure of mappings, iterables and scalars.
        values: Values to strip from mappings and iterables, e.g. `[None]`
            or `[[], None]`.

    Returns:
        The cleaned structure; `obj` itself is not modified.
    """
    if isinstance(obj, str):
        # Strings are iterable, but must be treated as scalars.
        return obj
    if isinstance(obj, Mapping):
        # Each element is cleaned exactly once: the walrus binds the result in
        # the filter so the same object is reused as the kept value. Recursing
        # twice per element is exponential in the nesting depth and exhausts
        # one-shot iterators before their contents can be kept.
        return {
            key: cleaned
            for key, val in obj.items()
            if (cleaned := nested_remove_values(val, values=values)) not in values
        }
    if isinstance(obj, Iterable):
        return [
            cleaned
            for val in obj
            if (cleaned := nested_remove_values(val, values=values)) not in values
        ]
    return obj
@router.get(path="/{setup_id}", response_model_exclude_none=True)
async def get_setup(
    setup_id: Annotated[int, Path()],
    expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> SetupResponse:
    """Get setup by id."""
    setup_row = await database.setups.get(setup_id, expdb_db)
    if not setup_row:
        msg = f"Setup {setup_id} not found."
        raise SetupNotFoundError(msg, code=281)

    parameter_rows = await database.setups.get_parameters(setup_id, expdb_db)

    # An empty parameter list is replaced by None so that
    # `response_model_exclude_none=True` drops the `parameter` field entirely.
    return SetupResponse(
        setup_parameters=SetupParameters(
            setup_id=setup_id,
            flow_id=setup_row.implementation_id,
            parameter=parameter_rows if parameter_rows else None,
        ),
    )
async def test_get_setup_response_is_identical_setup_doesnt_exist(
    py_api: httpx.AsyncClient,
    php_api: httpx.AsyncClient,
) -> None:
    # Query both implementations concurrently for an id that has no setup.
    setup_id = 999999
    endpoint = f"/setup/{setup_id}"
    php_response, py_response = await asyncio.gather(
        php_api.get(endpoint),
        py_api.get(endpoint),
    )

    # The HTTP status differs between the implementations (412 vs 404) ...
    assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
    assert py_response.status_code == HTTPStatus.NOT_FOUND

    # ... as does the message wording, but the error code must be the same.
    php_error = php_response.json()["error"]
    assert php_error["message"] == "Unknown setup"
    assert php_error["code"] == py_response.json()["code"]
    assert py_response.json()["detail"] == f"Setup {setup_id} not found."
@pytest.mark.parametrize("setup_id", range(1, 125))
async def test_get_setup_response_is_identical(
    setup_id: int,
    py_api: httpx.AsyncClient,
    php_api: httpx.AsyncClient,
) -> None:
    endpoint = f"/setup/{setup_id}"
    php_response, py_response = await asyncio.gather(
        php_api.get(endpoint),
        py_api.get(endpoint),
    )

    # Ids unknown to PHP only need to be unknown to the new API as well.
    if php_response.status_code == HTTPStatus.PRECONDITION_FAILED:
        assert py_response.status_code == HTTPStatus.NOT_FOUND
        return

    assert php_response.status_code == HTTPStatus.OK
    assert py_response.status_code == HTTPStatus.OK

    def normalise(payload: Any) -> Any:
        # PHP returns integer fields as strings, so numeric strings are converted
        # to numbers before comparing. PHP also returns `[]` where Python omits
        # the field, so empty lists and None values are stripped from both sides.
        return nested_remove_values(nested_str_to_num(payload), values=[[], None])

    assert normalise(php_response.json()) == normalise(py_response.json())
async def test_get_setup_unknown(py_api: httpx.AsyncClient) -> None:
    # Requesting an id without a matching setup must yield a 404.
    response = await py_api.get("/setup/999999")
    assert response.status_code == HTTPStatus.NOT_FOUND
    # `fullmatch` with an escaped dot pins the whole message; `re.match` with a
    # bare `.` would also accept trailing text or any final character.
    assert re.fullmatch(r"Setup \d+ not found\.", response.json()["detail"])


async def test_get_setup_success(py_api: httpx.AsyncClient) -> None:
    # Setup 1 is expected to be returned with its id and a parameter list.
    response = await py_api.get("/setup/1")
    assert response.status_code == HTTPStatus.OK
    data = response.json()["setup_parameters"]
    assert data["setup_id"] == 1
    assert "parameter" in data