From d69b05b4c1e71aafc29928be615198f5ad8bcede Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 3 May 2026 20:21:20 +0200 Subject: [PATCH 01/16] Start work on dataset untag --- src/database/datasets.py | 34 ++++++++++++++++ src/routers/openml/datasets.py | 20 +++++++++ tests/routers/openml/dataset_untag_test.py | 47 ++++++++++++++++++++++ 3 files changed, 101 insertions(+) create mode 100644 tests/routers/openml/dataset_untag_test.py diff --git a/src/database/datasets.py b/src/database/datasets.py index d6f9170..a99ac9b 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,6 +2,7 @@ import datetime from collections import defaultdict +from typing import NamedTuple from sqlalchemy import text from sqlalchemy.engine import Row @@ -45,6 +46,39 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: return row.one_or_none() +class Tag(NamedTuple): + content: str + creator: int + + +async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Tag | None: + row = ( + await connection.execute( + text( + """ + SELECT * + FROM dataset_tag + WHERE id = :dataset_id AND tag = :tag + """, + ), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + ).first() + return Tag(content=row.tag, creator=row.uploader) if row else None + + +async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None: + await connection.execute( + text( + """ + DELETE FROM dataset_tag + WHERE id = :dataset_id AND tag = :tag + """, + ), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]: row = await connection.execute( text( diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 68d86ae..e727101 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -2,6 +2,7 @@ import re from datetime import datetime from enum import StrEnum +from http import HTTPStatus from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends @@ -26,6 +27,8 @@ InternalError, NoResultsError, TagAlreadyExistsError, + TagNotFoundError, + TagNotOwnedError, ) from core.formatting import ( _csv_as_list, @@ -80,6 +83,23 @@ async def tag_dataset( } +@router.post(path="/untag", status_code=HTTPStatus.NO_CONTENT) +async def untag_dataset( + data_id: Annotated[Identifier, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> None: + dataset_tag = await database.datasets.get_tag(data_id, tag, expdb_db) + if not dataset_tag: + msg = f"Tag {tag!r} for dataset {data_id} not found." + raise TagNotFoundError(msg) + if dataset_tag.creator != user.user_id and not (await user.is_admin()): + msg = f"You are not allowed to remove {tag!r} from dataset {data_id}." + raise TagNotOwnedError(msg) + await database.datasets.delete_tag(data_id, tag, expdb_db) + + class DatasetStatusFilter(StrEnum): ACTIVE = DatasetStatus.ACTIVE DEACTIVATED = DatasetStatus.DEACTIVATED diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py new file mode 100644 index 0000000..d926a73 --- /dev/null +++ b/tests/routers/openml/dataset_untag_test.py @@ -0,0 +1,47 @@ +from http import HTTPStatus + +import httpx +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection + +from core.errors import TagNotFoundError +from routers.openml.datasets import untag_dataset +from tests.users import SOME_USER, ApiKey + + +async def test_dataset_untag_success( + py_api: httpx.AsyncClient, expdb_test: AsyncConnection +) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 2)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.SOME_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + + assert response.status_code == HTTPStatus.NO_CONTENT + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() is None + + +async def test_dataset_untag_tag_does_not_exist(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + with pytest.raises(TagNotFoundError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + assert e.value.status_code == HTTPStatus.NOT_FOUND + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail + + +# Dataset doesn't exist +# Tag is not owned From 06de426d7d7993c6eb42f84524023de6ca51ff98 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 3 May 2026 20:25:46 +0200 Subject: [PATCH 02/16] Add ownership tests --- tests/routers/openml/dataset_untag_test.py | 40 ++++++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index d926a73..9272a2c 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -5,9 +5,9 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncConnection -from core.errors import TagNotFoundError +from core.errors import TagNotFoundError, TagNotOwnedError from routers.openml.datasets import untag_dataset -from tests.users import SOME_USER, ApiKey +from tests.users import ADMIN_USER, SOME_USER, ApiKey async def test_dataset_untag_success( @@ -36,12 +36,46 @@ async def test_dataset_untag_success( async def test_dataset_untag_tag_does_not_exist(expdb_test: AsyncConnection) -> None: dataset_id = 1 tag = "foo" + with pytest.raises(TagNotFoundError) as e: await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + assert e.value.status_code == HTTPStatus.NOT_FOUND assert tag in e.value.detail assert str(dataset_id) in e.value.detail +async def test_dataset_untag_tag_not_owned(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 1)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + with pytest.raises(TagNotOwnedError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail + + +async def test_dataset_untag_admin_bypasses_ownership(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 1)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + await untag_dataset(dataset_id, tag, ADMIN_USER, expdb_test) + + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() is None + + # Dataset doesn't exist -# Tag is not owned From 05b05e0094c5da352a0c4c3489c54d29988c8a3b Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 3 May 2026 20:37:41 +0200 Subject: [PATCH 03/16] Just stick with row for now like the rest --- src/database/datasets.py | 11 ++--------- src/routers/openml/datasets.py | 2 +- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index a99ac9b..22a9868 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,7 +2,6 @@ import datetime from collections import defaultdict -from typing import NamedTuple from sqlalchemy import text from sqlalchemy.engine import Row @@ -46,13 +45,8 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: return row.one_or_none() -class Tag(NamedTuple): - content: str - creator: int - - -async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Tag | None: - row = ( +async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None: + return ( await connection.execute( text( """ @@ -64,7 +58,6 @@ async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Tag parameters={"dataset_id": dataset_id, "tag": tag}, ) ).first() - return Tag(content=row.tag, creator=row.uploader) if row else None async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None: diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index e727101..e9ae8d5 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -94,7 +94,7 @@ async def untag_dataset( if not dataset_tag: msg = f"Tag {tag!r} for dataset {data_id} not found." raise TagNotFoundError(msg) - if dataset_tag.creator != user.user_id and not (await user.is_admin()): + if dataset_tag.uploader != user.user_id and not (await user.is_admin()): msg = f"You are not allowed to remove {tag!r} from dataset {data_id}." raise TagNotOwnedError(msg) await database.datasets.delete_tag(data_id, tag, expdb_db) From c7851ad5819f4dde3d25e51384756c5bbbadc16b Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 3 May 2026 20:41:25 +0200 Subject: [PATCH 04/16] Return separate error if dataset does not exist Mainly for compatibility with the PHP REST API --- src/routers/openml/datasets.py | 4 ++++ tests/routers/openml/dataset_untag_test.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index e9ae8d5..b375e03 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -92,6 +92,10 @@ async def untag_dataset( ) -> None: dataset_tag = await database.datasets.get_tag(data_id, tag, expdb_db) if not dataset_tag: + dataset = await database.datasets.get(data_id, expdb_db) + if not dataset: + msg = f"Cannot remove {tag!r}, because dataset {data_id} is not found." + raise DatasetNotFoundError(msg, code=472) msg = f"Tag {tag!r} for dataset {data_id} not found." raise TagNotFoundError(msg) if dataset_tag.uploader != user.user_id and not (await user.is_admin()): diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index 9272a2c..c83193e 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -5,7 +5,7 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncConnection -from core.errors import TagNotFoundError, TagNotOwnedError +from core.errors import DatasetNotFoundError, TagNotFoundError, TagNotOwnedError from routers.openml.datasets import untag_dataset from tests.users import ADMIN_USER, SOME_USER, ApiKey @@ -78,4 +78,13 @@ async def test_dataset_untag_admin_bypasses_ownership(expdb_test: AsyncConnectio assert tag_present.scalar() is None -# Dataset doesn't exist +async def test_dataset_untag_dataset_is_not_exist(expdb_test: AsyncConnection) -> None: + dataset_id = 9_999_999 + tag = "foo" + + with pytest.raises(DatasetNotFoundError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + + assert e.value.status_code == HTTPStatus.NOT_FOUND + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail From 8dcd671cdc2a2df18a4d3ad9d3d5e3bd7fd22ce7 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 10:31:47 +0200 Subject: [PATCH 05/16] Make sure the tag remains in the database after denied request --- tests/routers/openml/dataset_untag_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index c83193e..f6d0095 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -60,6 +60,12 @@ async def test_dataset_untag_tag_not_owned(expdb_test: AsyncConnection) -> None: assert tag in e.value.detail assert str(dataset_id) in e.value.detail + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() == 1 + async def test_dataset_untag_admin_bypasses_ownership(expdb_test: AsyncConnection) -> None: dataset_id = 1 From 83a26325ef127efff5a046326942551cc9e2d206 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 10:49:11 +0200 Subject: [PATCH 06/16] Make temporary tags usable for other types of tags --- tests/conftest.py | 19 ++++++++++++++----- tests/routers/openml/dataset_untag_test.py | 6 ++++++ tests/routers/openml/setups_untag_test.py | 4 ++-- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 368b789..897e560 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -175,19 +175,28 @@ def temporary_tags( ) -> Callable[..., contextlib.AbstractAsyncContextManager[None]]: @contextlib.asynccontextmanager async def _temporary_tags( - tags: Iterable[str], setup_id: int, *, persist: bool = False + table: str, + tags: Iterable[str], + identifier: int, + *, + persist: bool = False, ) -> AsyncIterator[None]: insert_queries = [ ( - "INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);", - {"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id}, + f"INSERT INTO {table}(`id`,`tag`,`uploader`) VALUES (:identifier, :tag, :user_id);", # noqa: S608 # No user provided values + { + "table": table, + "identifier": identifier, + "tag": tag, + "user_id": OWNER_USER.user_id, + }, ) for tag in tags ] delete_queries = [ ( - "DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag", - {"setup_id": setup_id, "tag": tag}, + f"DELETE FROM {table} WHERE `id`=:identifier AND `tag`=:tag", # noqa: S608 # No user provided values + {"identifier": identifier, "tag": tag}, ) for tag in tags ] diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index f6d0095..e8369e7 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -94,3 +94,9 @@ async def test_dataset_untag_dataset_is_not_exist(expdb_test: AsyncConnection) - assert e.value.status_code == HTTPStatus.NOT_FOUND assert tag in e.value.detail assert str(dataset_id) in e.value.detail + + +@pytest.mark.mut +async def test_dataset_untag_is_identical( + py_api: httpx.AsyncClient, php_api: httpx.AsyncClient +) -> None: ... diff --git a/tests/routers/openml/setups_untag_test.py b/tests/routers/openml/setups_untag_test.py index 1ed7b42..55d4b28 100644 --- a/tests/routers/openml/setups_untag_test.py +++ b/tests/routers/openml/setups_untag_test.py @@ -144,7 +144,7 @@ async def test_setup_untag_response_is_identical_when_tag_exists( tag = "totally_new_tag_for_migration_testing" all_tags = [tag, *other_tags] - async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True): + async with temporary_tags(table="setup_tag", tags=all_tags, identifier=setup_id, persist=True): php_response = await php_api.post( "/setup/untag", data={"api_key": api_key, "tag": tag, "setup_id": setup_id}, @@ -152,7 +152,7 @@ async def test_setup_untag_response_is_identical_when_tag_exists( # expdb_test transaction shared with Python API, # no commit needed and rolled back at the end of the test - async with temporary_tags(tags=all_tags, setup_id=setup_id): + async with temporary_tags(table="setup_tag", tags=all_tags, identifier=setup_id): py_response = await py_api.post( f"/setup/untag?api_key={api_key}", json={"setup_id": setup_id, "tag": tag}, From 638c58fc7dad291c2ff4c3a17fe67d2f9be06440 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 11:15:24 +0200 Subject: [PATCH 07/16] Write migration test --- tests/routers/openml/dataset_untag_test.py | 25 +++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index e8369e7..9329d24 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -1,3 +1,5 @@ +from collections.abc import Callable +from contextlib import AbstractAsyncContextManager from http import HTTPStatus import httpx @@ -97,6 +99,23 @@ async def test_dataset_untag_dataset_is_not_exist(expdb_test: AsyncConnection) - @pytest.mark.mut -async def test_dataset_untag_is_identical( - py_api: httpx.AsyncClient, php_api: httpx.AsyncClient -) -> None: ... +async def test_dataset_untag_success_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, + temporary_tags: Callable[..., AbstractAsyncContextManager[None]], +) -> None: + dataset_id = 1 + tag = "foo" + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id, persist=True): + php_response = await php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ) + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id): + py_response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ) + + assert py_response.status_code == php_response.status_code + assert py_response.json() == php_response.json() From 78a7983c9748f4d20789f8e7fa825592420191b1 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 14:31:17 +0200 Subject: [PATCH 08/16] Separate types definition from where they are used --- src/routers/openml/datasets.py | 32 ++++++++++----------- src/routers/openml/flows.py | 3 +- src/routers/openml/qualities.py | 3 +- src/routers/openml/runs.py | 3 +- src/routers/openml/setups.py | 12 ++++---- src/routers/openml/study.py | 9 +++--- src/routers/openml/tasks.py | 28 +++++++++++------- src/routers/types.py | 22 +++++++------- tests/routers/openml/datasets_get_test.py | 35 ++++++++++++----------- 9 files changed, 79 insertions(+), 68 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index b375e03..6c81eda 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -61,7 +61,7 @@ ) async def tag_dataset( data_id: Annotated[Identifier, Body()], - tag: Annotated[str, SystemString64], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, Any]]: @@ -86,7 +86,7 @@ async def tag_dataset( @router.post(path="/untag", status_code=HTTPStatus.NO_CONTENT) async def untag_dataset( data_id: Annotated[Identifier, Body()], - tag: Annotated[str, SystemString64], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> None: @@ -132,27 +132,27 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.get(path="/list") async def list_datasets( # noqa: PLR0913, C901 pagination: Annotated[Pagination, Body(default_factory=Pagination)], - data_name: Annotated[str | None, CasualString128] = None, - tag: Annotated[str | None, SystemString64] = None, + data_name: Annotated[CasualString128 | None, Body()] = None, + tag: Annotated[SystemString64 | None, Body()] = None, data_version: Annotated[ - int | None, + Identifier | None, Body(description="The dataset version to include in the search."), ] = None, uploader: Annotated[ - int | None, + Identifier | None, Body(description="User id of the uploader whose datasets to include in the search."), ] = None, data_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body( description="The dataset(s) to include in the search. " "If none are specified, all datasets are included.", ), ] = None, - number_instances: Annotated[str | None, IntegerRange] = None, - number_features: Annotated[str | None, IntegerRange] = None, - number_classes: Annotated[str | None, IntegerRange] = None, - number_missing_values: Annotated[str | None, IntegerRange] = None, + number_instances: Annotated[IntegerRange | None, Body()] = None, + number_features: Annotated[IntegerRange | None, Body()] = None, + number_classes: Annotated[IntegerRange | None, Body()] = None, + number_missing_values: Annotated[IntegerRange | None, Body()] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, @@ -290,7 +290,7 @@ class ProcessingInformation(NamedTuple): async def _get_processing_information( - dataset_id: int, + dataset_id: Identifier, connection: AsyncConnection, ) -> ProcessingInformation: """Return processing information, if any. Otherwise, all fields `None`.""" @@ -309,7 +309,7 @@ async def _get_processing_information( async def _get_dataset_raise_otherwise( - dataset_id: int, + dataset_id: Identifier, user: User | None, expdb: AsyncConnection, ) -> Row[Any]: @@ -330,7 +330,7 @@ async def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) async def get_dataset_features( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: @@ -373,7 +373,7 @@ async def get_dataset_features( path="/status/update", ) async def update_dataset_status( - dataset_id: Annotated[int, Body()], + dataset_id: Annotated[Identifier, Body()], status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)], @@ -427,7 +427,7 @@ async def update_dataset_status( description="Get meta-data for dataset with ID `dataset_id`.", ) async def get_dataset( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)] = None, user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 44c2338..b346793 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -8,6 +8,7 @@ from core.conversions import _str_to_num from core.errors import FlowNotFoundError from routers.dependencies import expdb_connection +from routers.types import Identifier from schemas.flows import Flow, Parameter, Subflow router = APIRouter(prefix="/flows", tags=["flows"]) @@ -33,7 +34,7 @@ async def flow_exists( @router.get("/{flow_id}") async def get_flow( - flow_id: int, + flow_id: Identifier, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Flow: flow = await database.flows.get(flow_id, expdb) diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index eff7081..f7684ca 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -14,6 +14,7 @@ ) from database.users import User from routers.dependencies import expdb_connection, fetch_user +from routers.types import Identifier from schemas.datasets.openml import Quality router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -33,7 +34,7 @@ async def list_qualities( @router.get("/qualities/{dataset_id}") async def get_qualities( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> list[Quality]: diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 37a7cec..641108f 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -8,6 +8,7 @@ import database.runs from core.errors import RunNotFoundError, RunTraceNotFoundError from routers.dependencies import expdb_connection +from routers.types import Identifier from schemas.runs import RunTrace, TraceIteration router = APIRouter(prefix="/run", tags=["run"]) @@ -15,7 +16,7 @@ @router.get("/trace/{run_id}") async def get_run_trace( - run_id: int, + run_id: Identifier, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> RunTrace: """Get trace data for a run by run ID.""" diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py index 89abb24..2d8bd57 100644 --- a/src/routers/openml/setups.py +++ b/src/routers/openml/setups.py @@ -16,7 +16,7 @@ ) from database.users import User from routers.dependencies import expdb_connection, fetch_user_or_raise -from routers.types import SystemString64 +from routers.types import Identifier, SystemString64 from schemas.setups import SetupParameters, SetupResponse router = APIRouter(prefix="/setup", tags=["setup"]) @@ -24,7 +24,7 @@ @router.get(path="/{setup_id}", response_model_exclude_none=True) async def get_setup( - setup_id: Annotated[int, Path()], + setup_id: Annotated[Identifier, Path()], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> SetupResponse: """Get setup by id.""" @@ -46,8 +46,8 @@ async def get_setup( @router.post(path="/tag") async def tag_setup( - setup_id: Annotated[int, Body()], - tag: Annotated[str, SystemString64], + setup_id: Annotated[Identifier, Body()], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: @@ -73,8 +73,8 @@ async def tag_setup( @router.post(path="/untag") async def untag_setup( - setup_id: Annotated[int, Body()], - tag: Annotated[str, SystemString64], + setup_id: Annotated[Identifier, Body()], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 56c670c..7c4ae72 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -20,6 +20,7 @@ from core.formatting import _str_to_bool from database.users import User from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise +from routers.types import Identifier from schemas.core import Visibility from schemas.study import CreateStudy, Study, StudyStatus, StudyType @@ -27,7 +28,7 @@ async def _get_study_raise_otherwise( - id_or_alias: int | str, + id_or_alias: Identifier | str, user: User | None, expdb: AsyncConnection, ) -> Row: @@ -61,8 +62,8 @@ class AttachDetachResponse(BaseModel): @router.post("/attach") async def attach_to_study( - study_id: Annotated[int, Body()], - entity_ids: Annotated[list[int], Body()], + study_id: Annotated[Identifier, Body()], + entity_ids: Annotated[list[Identifier], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: @@ -148,7 +149,7 @@ async def create_study( @router.get("/{alias_or_id}") async def get_study( - alias_or_id: int | str, + alias_or_id: Identifier | str, user: Annotated[User | None, Depends(fetch_user)] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 3dfa594..80cdacd 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -15,7 +15,13 @@ import database.tasks from core.errors import InternalError, NoResultsError, TaskNotFoundError from routers.dependencies import Pagination, expdb_connection -from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex +from routers.types import ( + CasualString128, + Identifier, + IntegerRange, + SystemString64, + integer_range_regex, +) from schemas.datasets.openml import Task router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -221,23 +227,23 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.get(path="/list") async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 pagination: Annotated[Pagination, Body(default_factory=Pagination)], - task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, - tag: Annotated[str | None, SystemString64] = None, - data_tag: Annotated[str | None, SystemString64] = None, + task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None, + tag: Annotated[SystemString64 | None, Body()] = None, + data_tag: Annotated[SystemString64 | None, Body()] = None, status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, task_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body(description="Filter by task id(s).", min_length=1), ] = None, data_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body(description="Filter by dataset id(s).", min_length=1), ] = None, - data_name: Annotated[str | None, CasualString128] = None, - number_instances: Annotated[str | None, IntegerRange] = None, - number_features: Annotated[str | None, IntegerRange] = None, - number_classes: Annotated[str | None, IntegerRange] = None, - number_missing_values: Annotated[str | None, IntegerRange] = None, + data_name: Annotated[CasualString128 | None, Body()] = None, + number_instances: Annotated[IntegerRange | None, Body()] = None, + number_features: Annotated[IntegerRange | None, Body()] = None, + number_classes: Annotated[IntegerRange | None, Body()] = None, + number_missing_values: Annotated[IntegerRange | None, Body()] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" diff --git a/src/routers/types.py b/src/routers/types.py index e107ff3..fcdb876 100644 --- a/src/routers/types.py +++ b/src/routers/types.py @@ -1,18 +1,18 @@ from typing import Annotated -from fastapi import Body from pydantic import Field -SystemString64 = Body(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64) - -CasualString128 = Body(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128) - +SystemString64 = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)] +CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)] Identifier = Annotated[int, Field(gt=0)] integer_range_regex = r"^(\d+)(\.\.\d+)?$" -IntegerRange = Body( - pattern=integer_range_regex, - description="Either a single integer, or a range defined as `low..high`, where" - "`low` and `high` are inclusive integer bounds of the range.", - examples=["12", "3..150"], -) +IntegerRange = Annotated[ + str, + Field( + pattern=integer_range_regex, + description="Either a single integer, or a range defined as `low..high`, where" + "`low` and `high` are inclusive integer bounds of the range.", + examples=["12", "3..150"], + ), +] diff --git a/tests/routers/openml/datasets_get_test.py b/tests/routers/openml/datasets_get_test.py index 4b9fb33..6bcd51b 100644 --- a/tests/routers/openml/datasets_get_test.py +++ b/tests/routers/openml/datasets_get_test.py @@ -89,20 +89,24 @@ async def test_dataset_no_500_with_multiple_processing_entries( @pytest.mark.parametrize( "dataset_id", - [-1, 138, 100_000], + [138, 100_000], ) async def test_get_dataset_not_found( dataset_id: int, expdb_test: AsyncConnection, user_test: AsyncConnection, ) -> None: - with pytest.raises(DatasetNotFoundError): + with pytest.raises(DatasetNotFoundError) as exc_info: await get_dataset( dataset_id=dataset_id, user=None, user_db=user_test, expdb_db=expdb_test, ) + assert exc_info.value.status_code == HTTPStatus.NOT_FOUND + _dataset_get_not_found_code = 111 + assert exc_info.value.code == _dataset_get_not_found_code + assert exc_info.value.detail.startswith("No dataset") @pytest.mark.parametrize( @@ -235,24 +239,21 @@ async def test_dataset_response_is_identical( # noqa: C901, PLR0912 assert py_json == php_json -@pytest.mark.parametrize( - "dataset_id", - [-1, 138, 100_000], -) -async def test_error_unknown_dataset( - dataset_id: int, +async def test_dataset_not_found_is_identical( py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, ) -> None: - response = await py_api.get(f"/datasets/{dataset_id}") + dataset_id = 9_999_999 + py_response, php_response = await asyncio.gather( + py_api.get(f"/datasets/{dataset_id}"), + php_api.get(f"/datasets/{dataset_id}"), + ) - # The new API has "404 Not Found" instead of "412 PRECONDITION_FAILED" - assert response.status_code == HTTPStatus.NOT_FOUND - # RFC 9457: Python API now returns problem+json format - assert response.headers["content-type"] == "application/problem+json" - error = response.json() - assert error["code"] == "111" - # instead of 'Unknown dataset' - assert error["detail"].startswith("No dataset") + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert py_response.json()["detail"] == f"Dataset {dataset_id} not found." + assert php_response.json()["error"]["message"] == "Dataset not found." async def test_private_dataset_no_user_no_access( From ac16c1e693dca296767ae14d15a18bd15af4020b Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 14:41:58 +0200 Subject: [PATCH 09/16] Add alternative more semantically correct untag endpoint --- pyproject.toml | 2 +- src/routers/openml/datasets.py | 49 +++++++++++++++------- tests/routers/openml/dataset_untag_test.py | 26 +++++++++--- 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d8078bf..e1b7e60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ ] description = "The Python-based REST API for OpenML." readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.14" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 6c81eda..ba3e129 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -3,13 +3,11 @@ from datetime import datetime from enum import StrEnum from http import HTTPStatus -from typing import Annotated, Any, Literal, NamedTuple +from typing import TYPE_CHECKING, Annotated, Any, Literal, NamedTuple, TypedDict -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, Query from loguru import logger from sqlalchemy import bindparam, text -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -53,6 +51,10 @@ ) from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -83,25 +85,44 @@ async def tag_dataset( } -@router.post(path="/untag", status_code=HTTPStatus.NO_CONTENT) -async def untag_dataset( +class UntagInfo(TypedDict): + id: str + tag: SystemString64 | list[SystemString64] + + +@router.post(path="/untag", deprecated=True) +async def untag_dataset_like_php( data_id: Annotated[Identifier, Body()], tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[Literal["data_untag"], UntagInfo]: + await untag_dataset(data_id, tag, user, expdb_db) + tags = await database.datasets.get_tags_for(id_=data_id, connection=expdb_db) + return_tags = tags if len(tags) > 1 else tags[0] + return {"data_untag": {"id": str(data_id), "tag": return_tags}} + + +@router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT) +async def untag_dataset( + identifier: Identifier, + tag: Annotated[SystemString64, Query()], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> None: - dataset_tag = await database.datasets.get_tag(data_id, tag, expdb_db) + dataset_tag = await database.datasets.get_tag(identifier, tag, expdb_db) if not dataset_tag: - dataset = await database.datasets.get(data_id, expdb_db) - if not dataset: - msg = f"Cannot remove {tag!r}, because dataset {data_id} is not found." - raise DatasetNotFoundError(msg, code=472) - msg = f"Tag {tag!r} for dataset {data_id} not found." + try: + await _get_dataset_raise_otherwise(identifier, user, expdb_db) + except DatasetNotFoundError, DatasetNoAccessError: + msg = f"Cannot remove {tag!r}, because dataset {identifier} is not found." + raise DatasetNotFoundError(msg, code=472) from None + msg = f"Tag {tag!r} for dataset {identifier} not found." raise TagNotFoundError(msg) if dataset_tag.uploader != user.user_id and not (await user.is_admin()): - msg = f"You are not allowed to remove {tag!r} from dataset {data_id}." + msg = f"You are not allowed to remove {tag!r} from dataset {identifier}." raise TagNotOwnedError(msg) - await database.datasets.delete_tag(data_id, tag, expdb_db) + await database.datasets.delete_tag(identifier, tag, expdb_db) class DatasetStatusFilter(StrEnum): diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index 9329d24..3b56262 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -1,16 +1,31 @@ +"""Tests for untagging a dataset. + +There are currently two endpoints for untagging a dataset: + + POST /datasets/untag + DEL /datasets/{id}/tag + +The former is provided for compatibility with the old API, and is tested in the migration test. +The latter is more semantically correct, and is used for the Python tests. +They share most of the underlying logic anyway. +""" + from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import DatasetNotFoundError, TagNotFoundError, TagNotOwnedError from routers.openml.datasets import untag_dataset from tests.users import ADMIN_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_dataset_untag_success( py_api: httpx.AsyncClient, expdb_test: AsyncConnection @@ -22,9 +37,8 @@ async def test_dataset_untag_success( parameters={"dataset_id": dataset_id, "tag": tag}, ) - response = await py_api.post( - f"/datasets/untag?api_key={ApiKey.SOME_USER}", - json={"data_id": dataset_id, "tag": tag}, + response = await py_api.delete( + f"/datasets/{dataset_id}/tag?api_key={ApiKey.SOME_USER}&tag={tag}", ) assert response.status_code == HTTPStatus.NO_CONTENT @@ -86,7 +100,7 @@ async def test_dataset_untag_admin_bypasses_ownership(expdb_test: AsyncConnectio assert tag_present.scalar() is None -async def test_dataset_untag_dataset_is_not_exist(expdb_test: AsyncConnection) -> None: +async def test_dataset_untag_dataset_does_not_exist(expdb_test: AsyncConnection) -> None: dataset_id = 9_999_999 tag = "foo" From 6d0bb5e774a3e0a7fdfff6066630aa0c4326e968 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 14:54:05 +0200 Subject: [PATCH 10/16] Add additional migration tests for untag --- tests/routers/openml/dataset_untag_test.py | 81 ++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index 3b56262..4d26d54 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -10,6 +10,7 @@ They share most of the underlying logic anyway. """ +import asyncio from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus @@ -133,3 +134,83 @@ async def test_dataset_untag_success_is_identical( assert py_response.status_code == php_response.status_code assert py_response.json() == php_response.json() + + +@pytest.mark.mut +async def test_dataset_untag_tag_does_not_exist_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + dataset_id = 1 + tag = "foo" + + py_response, php_response = await asyncio.gather( + py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ), + php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ), + ) + + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Tag not found." + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not found" in py_response.json()["detail"] + + +@pytest.mark.mut +async def test_dataset_untag_dataset_does_not_exist_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + dataset_id = 9_999_999 + tag = "foo" + + py_response, php_response = await asyncio.gather( + py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ), + php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ), + ) + + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Entity not found." + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not found" in py_response.json()["detail"] + + +@pytest.mark.mut +async def test_dataset_untag_tag_not_owned_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, + temporary_tags: Callable[..., AbstractAsyncContextManager[None]], +) -> None: + dataset_id = 1 + tag = "foo" + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id, persist=True): + php_response = await php_api.post( + f"/data/untag?api_key={ApiKey.SOME_USER}", data={"tag": tag, "data_id": dataset_id} + ) + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id): + py_response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.SOME_USER}", json={"tag": tag, "data_id": dataset_id} + ) + + assert py_response.status_code == HTTPStatus.FORBIDDEN + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Tag is not owned by you" + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not allowed" in py_response.json()["detail"] From 768c66bde0c66ee5469712a6f343aed23cd8646a Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:12:24 +0200 Subject: [PATCH 11/16] Use TYPE_CHECKING to conditionally import packages --- src/core/access.py | 7 ++++--- src/core/errors.py | 7 +++++-- src/core/formatting.py | 6 ++++-- src/core/logging.py | 7 +++++-- src/database/datasets.py | 7 +++++-- src/database/evaluations.py | 6 ++++-- src/database/flows.py | 6 ++++-- src/database/qualities.py | 5 ++++- src/database/runs.py | 6 ++++-- src/database/setups.py | 8 ++++++-- src/database/studies.py | 6 ++++-- src/database/tasks.py | 6 ++++-- src/database/users.py | 6 ++++-- src/routers/dependencies.py | 6 ++++-- src/routers/mldcat_ap/dataset.py | 6 ++++-- src/routers/openml/estimation_procedure.py | 6 ++++-- src/routers/openml/evaluations.py | 6 ++++-- src/routers/openml/flows.py | 6 ++++-- src/routers/openml/qualities.py | 6 ++++-- src/routers/openml/runs.py | 6 ++++-- src/routers/openml/setups.py | 6 ++++-- src/routers/openml/study.py | 8 +++++--- src/routers/openml/tasks.py | 8 +++++--- src/routers/openml/tasktype.py | 8 +++++--- tests/conftest.py | 12 +++++++----- tests/database/flows_test.py | 5 ++++- tests/dependencies/fetch_user_test.py | 5 ++++- tests/routers/openml/dataset_tag_test.py | 7 +++++-- tests/routers/openml/datasets_features_test.py | 7 +++++-- tests/routers/openml/datasets_get_test.py | 7 +++++-- tests/routers/openml/datasets_list_datasets_test.py | 8 +++++--- tests/routers/openml/datasets_qualities_test.py | 5 ++++- tests/routers/openml/datasets_status_test.py | 7 +++++-- tests/routers/openml/estimation_procedure_test.py | 5 +++-- tests/routers/openml/evaluation_measure_test.py | 4 +++- tests/routers/openml/flows_exists_test.py | 9 ++++++--- tests/routers/openml/flows_get_test.py | 6 ++++-- tests/routers/openml/qualities_list_test.py | 7 +++++-- tests/routers/openml/runs_trace_test.py | 6 ++++-- tests/routers/openml/setups_get_test.py | 5 ++++- tests/routers/openml/setups_tag_test.py | 7 +++++-- tests/routers/openml/setups_untag_test.py | 7 +++++-- tests/routers/openml/study_attach_test.py | 7 +++++-- tests/routers/openml/study_get_test.py | 5 ++++- tests/routers/openml/study_post_test.py | 5 ++++- tests/routers/openml/task_get_test.py | 5 ++++- tests/routers/openml/task_list_test.py | 8 +++++--- tests/routers/openml/task_type_get_test.py | 5 ++++- tests/routers/openml/task_type_list_test.py | 4 +++- 49 files changed, 216 insertions(+), 97 deletions(-) diff --git a/src/core/access.py b/src/core/access.py index 558643f..871fb48 100644 --- a/src/core/access.py +++ b/src/core/access.py @@ -1,10 +1,11 @@ -from typing import Any - -from sqlalchemy.engine import Row +from typing import TYPE_CHECKING, Any from database.users import User from schemas.datasets.openml import Visibility +if TYPE_CHECKING: + from sqlalchemy.engine import Row + async def _user_has_access( dataset: Row[Any], diff --git a/src/core/errors.py b/src/core/errors.py index 69d7e0c..5fe5437 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -5,11 +5,14 @@ """ from http import HTTPStatus +from typing import TYPE_CHECKING -from fastapi import Request -from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +if TYPE_CHECKING: + from fastapi import Request + from fastapi.exceptions import RequestValidationError + # ============================================================================= # Base Exception # ============================================================================= diff --git a/src/core/formatting.py b/src/core/formatting.py index f954e81..406659f 100644 --- a/src/core/formatting.py +++ b/src/core/formatting.py @@ -1,10 +1,12 @@ import html - -from sqlalchemy.engine import Row +from typing import TYPE_CHECKING from config import load_routing_configuration from schemas.datasets.openml import DatasetFileFormat +if TYPE_CHECKING: + from sqlalchemy.engine import Row + def _str_to_bool(string: str) -> bool: if string.casefold() in ["true", "1", "yes", "y"]: diff --git a/src/core/logging.py b/src/core/logging.py index b35270d..6546f71 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -5,13 +5,16 @@ import uuid from collections.abc import Awaitable, Callable from pathlib import Path +from typing import TYPE_CHECKING from loguru import logger -from starlette.requests import Request -from starlette.responses import Response from config import load_configuration +if TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + def setup_log_sinks(configuration_file: Path | None = None) -> None: """Configure loguru based on app configuration.""" diff --git a/src/database/datasets.py b/src/database/datasets.py index 22a9868..efa7745 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,11 +2,10 @@ import datetime from collections import defaultdict +from typing import TYPE_CHECKING from sqlalchemy import text -from sqlalchemy.engine import Row from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncConnection from database.exceptions import ( _DUPLICATE_ENTRY, @@ -16,6 +15,10 @@ ) from schemas.datasets.openml import Feature +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + async def get(id_: int, connection: AsyncConnection) -> Row | None: row = await connection.execute( diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 74faf59..382653f 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,12 +1,14 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection from core.formatting import _str_to_bool from schemas.datasets.openml import EstimationProcedure +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: rows = await connection.execute( diff --git a/src/database/flows.py b/src/database/flows.py index 79bb6e5..ed022c4 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,8 +1,10 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: diff --git a/src/database/qualities.py b/src/database/qualities.py index 08647f4..9180b13 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -1,11 +1,14 @@ from collections import defaultdict from collections.abc import Iterable +from typing import TYPE_CHECKING from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Quality +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]: row = await connection.execute( diff --git a/src/database/runs.py b/src/database/runs.py index acf7a53..6eef0b3 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -1,10 +1,12 @@ """Database queries for run-related data.""" from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def exist(id_: int, expdb: AsyncConnection) -> bool: diff --git a/src/database/setups.py b/src/database/setups.py index 7449847..c8651d9 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -1,8 +1,12 @@ """All database operations that directly operate on setups.""" +from typing import TYPE_CHECKING + from sqlalchemy import text -from sqlalchemy.engine import Row, RowMapping -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.engine import Row, RowMapping + from sqlalchemy.ext.asyncio import AsyncConnection async def get(setup_id: int, connection: AsyncConnection) -> Row | None: diff --git a/src/database/studies.py b/src/database/studies.py index 286f798..3c7c8e9 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -1,14 +1,16 @@ import re from collections.abc import Sequence from datetime import UTC, datetime -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from schemas.study import CreateStudy, StudyType +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None: row = await connection.execute( diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d2..ba4eecc 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,8 +1,10 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def get(id_: int, expdb: AsyncConnection) -> Row | None: diff --git a/src/database/users.py b/src/database/users.py index 0b09fb0..8a812d6 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -1,13 +1,15 @@ import dataclasses from enum import IntEnum -from typing import Annotated, Self +from typing import TYPE_CHECKING, Annotated, Self from pydantic import StringConstraints from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from config import load_configuration +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + # If `allow_test_api_keys` is set, the key may also be one of `normaluser`, # `normaluser2`, or `abc` (admin). api_key_pattern = r"^[0-9a-fA-F]{32}$" diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index dba12d8..ca4a965 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,15 +1,17 @@ from collections.abc import AsyncGenerator, AsyncIterator -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import Depends from loguru import logger from pydantic import BaseModel, Field -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.setup import expdb_database, user_database from database.users import APIKey, User +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def expdb_connection() -> AsyncIterator[AsyncConnection]: engine = expdb_database() diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index 998f940..2749664 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -5,10 +5,9 @@ """ import asyncio -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncConnection import config from core.errors import ServiceNotFoundError @@ -28,6 +27,9 @@ Quality, ) +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/mldcat_ap", tags=["MLDCAT-AP"]) _configuration = config.load_configuration() _server_url = ( diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index b07c2c0..d8532dc 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -1,12 +1,14 @@ -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection from schemas.datasets.openml import EstimationProcedure +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/estimationprocedure", tags=["estimationprocedure"]) diff --git a/src/routers/openml/evaluations.py b/src/routers/openml/evaluations.py index f6650b3..f289131 100644 --- a/src/routers/openml/evaluations.py +++ b/src/routers/openml/evaluations.py @@ -1,11 +1,13 @@ -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/evaluationmeasure", tags=["evaluationmeasure"]) diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index b346793..b3c3f21 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,8 +1,7 @@ import asyncio -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from core.conversions import _str_to_num @@ -11,6 +10,9 @@ from routers.types import Identifier from schemas.flows import Flow, Parameter, Subflow +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/flows", tags=["flows"]) diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index f7684ca..dc7a3f5 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -1,7 +1,6 @@ -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -17,6 +16,9 @@ from routers.types import Identifier from schemas.datasets.openml import Quality +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/datasets", tags=["datasets"]) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 641108f..4fc7948 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,9 +1,8 @@ """Endpoints for run-related data.""" -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.runs from core.errors import RunNotFoundError, RunTraceNotFoundError @@ -11,6 +10,9 @@ from routers.types import Identifier from schemas.runs import RunTrace, TraceIteration +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/run", tags=["run"]) diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py index 2d8bd57..dbca903 100644 --- a/src/routers/openml/setups.py +++ b/src/routers/openml/setups.py @@ -1,11 +1,10 @@ """All endpoints that relate to setups.""" import asyncio -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Body, Depends, Path from loguru import logger -from sqlalchemy.ext.asyncio import AsyncConnection import database.setups from core.errors import ( @@ -19,6 +18,9 @@ from routers.types import Identifier, SystemString64 from schemas.setups import SetupParameters, SetupResponse +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/setup", tags=["setup"]) diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 7c4ae72..ccdf9ff 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -1,10 +1,8 @@ -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Body, Depends from loguru import logger from pydantic import BaseModel -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection import database.studies from core.errors import ( @@ -24,6 +22,10 @@ from schemas.core import Visibility from schemas.study import CreateStudy, Study, StudyStatus, StudyType +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/studies", tags=["studies"]) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 80cdacd..6627d79 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -2,13 +2,11 @@ import json import re from enum import StrEnum -from typing import Annotated, Any, cast +from typing import TYPE_CHECKING, Annotated, Any, cast import xmltodict from fastapi import APIRouter, Body, Depends from sqlalchemy import bindparam, text -from sqlalchemy.engine import RowMapping -from sqlalchemy.ext.asyncio import AsyncConnection import config import database.datasets @@ -24,6 +22,10 @@ ) from schemas.datasets.openml import Task +if TYPE_CHECKING: + from sqlalchemy.engine import RowMapping + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/tasks", tags=["tasks"]) type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5355e45..7cdba50 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,15 +1,17 @@ import json -from typing import Annotated, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast from fastapi import APIRouter, Depends -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import TaskTypeNotFoundError from database.tasks import get_input_for_task_type, get_task_types from database.tasks import get_task_type as db_get_task_type from routers.dependencies import expdb_connection +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/tasktype", tags=["tasks"]) diff --git a/tests/conftest.py b/tests/conftest.py index 897e560..850a434 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,23 +2,25 @@ import json from collections.abc import AsyncIterator, Callable, Iterable, Iterator from pathlib import Path -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import _pytest.mark import httpx import pytest -from _pytest.config import Config -from _pytest.nodes import Item from asgi_lifespan import LifespanManager -from fastapi import FastAPI from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from database.setup import expdb_database, user_database from main import create_api from routers.dependencies import expdb_connection, userdb_connection from tests.users import OWNER_USER +if TYPE_CHECKING: + from _pytest.config import Config + from _pytest.nodes import Item + from fastapi import FastAPI + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + PHP_API_URL = "http://php-api:80/api/v1/json" diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py index a8b98d8..6e34577 100644 --- a/tests/database/flows_test.py +++ b/tests/database/flows_test.py @@ -1,8 +1,11 @@ -from sqlalchemy.ext.asyncio import AsyncConnection +from typing import TYPE_CHECKING import database.flows from tests.conftest import Flow +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_database_flow_exists(flow: Flow, expdb_test: AsyncConnection) -> None: retrieved_flow = await database.flows.get_by_name(flow.name, flow.external_version, expdb_test) diff --git a/tests/dependencies/fetch_user_test.py b/tests/dependencies/fetch_user_test.py index 116bbdd..c9bfa08 100644 --- a/tests/dependencies/fetch_user_test.py +++ b/tests/dependencies/fetch_user_test.py @@ -1,13 +1,16 @@ from contextlib import aclosing +from typing import TYPE_CHECKING import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.users import User from routers.dependencies import fetch_user, fetch_user_or_raise from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + @pytest.mark.parametrize( ("api_key", "user"), diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 79c982f..5bcfb4a 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,9 +1,8 @@ import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.conversions import nested_remove_single_element_list from core.errors import DatasetNotFoundError, TagAlreadyExistsError @@ -13,6 +12,10 @@ from tests import constants from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + @pytest.mark.parametrize( "key", diff --git a/tests/routers/openml/datasets_features_test.py b/tests/routers/openml/datasets_features_test.py index 1fd8985..3616659 100644 --- a/tests/routers/openml/datasets_features_test.py +++ b/tests/routers/openml/datasets_features_test.py @@ -3,16 +3,19 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import DatasetNoAccessError, DatasetNotFoundError, DatasetProcessingError from database.users import User from routers.openml.datasets import get_dataset_features from tests.users import ADMIN_USER, DATASET_130_OWNER +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_get_features_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/features/4") diff --git a/tests/routers/openml/datasets_get_test.py b/tests/routers/openml/datasets_get_test.py index 6bcd51b..8772d05 100644 --- a/tests/routers/openml/datasets_get_test.py +++ b/tests/routers/openml/datasets_get_test.py @@ -4,11 +4,10 @@ import json import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection import tests.constants from core.errors import DatasetNoAccessError, DatasetNotFoundError @@ -17,6 +16,10 @@ from schemas.datasets.openml import DatasetMetadata from tests.users import ADMIN_USER, DATASET_130_OWNER, NO_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_get_dataset_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/1") diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index 7460fb5..dc0523b 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -1,13 +1,11 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any -import httpx import hypothesis import pytest from hypothesis import given from hypothesis import strategies as st -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import NoResultsError from database.users import User @@ -16,6 +14,10 @@ from tests import constants from tests.users import ADMIN_USER, DATASET_130_OWNER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_list_route(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/list/") diff --git a/tests/routers/openml/datasets_qualities_test.py b/tests/routers/openml/datasets_qualities_test.py index fb3559c..d67e54a 100644 --- a/tests/routers/openml/datasets_qualities_test.py +++ b/tests/routers/openml/datasets_qualities_test.py @@ -1,11 +1,14 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx import pytest +if TYPE_CHECKING: + import httpx + async def test_get_quality(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/qualities/1") diff --git a/tests/routers/openml/datasets_status_test.py b/tests/routers/openml/datasets_status_test.py index 1e2271f..adc5892 100644 --- a/tests/routers/openml/datasets_status_test.py +++ b/tests/routers/openml/datasets_status_test.py @@ -1,10 +1,9 @@ """Tests for the POST /datasets/status/update endpoint.""" from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import DatasetAdminOnlyError, DatasetNotOwnedError from routers.openml.datasets import update_dataset_status @@ -12,6 +11,10 @@ from tests import constants from tests.users import ADMIN_USER, SOME_USER +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_update_status_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.post( diff --git a/tests/routers/openml/estimation_procedure_test.py b/tests/routers/openml/estimation_procedure_test.py index a05b34d..dafa35c 100644 --- a/tests/routers/openml/estimation_procedure_test.py +++ b/tests/routers/openml/estimation_procedure_test.py @@ -1,8 +1,9 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any -import httpx +if TYPE_CHECKING: + import httpx async def test_estimation_procedure_list(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/evaluation_measure_test.py b/tests/routers/openml/evaluation_measure_test.py index 2df2483..feecb6b 100644 --- a/tests/routers/openml/evaluation_measure_test.py +++ b/tests/routers/openml/evaluation_measure_test.py @@ -1,7 +1,9 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx +if TYPE_CHECKING: + import httpx async def test_evaluationmeasure_list(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/flows_exists_test.py b/tests/routers/openml/flows_exists_test.py index bb09edd..988402f 100644 --- a/tests/routers/openml/flows_exists_test.py +++ b/tests/routers/openml/flows_exists_test.py @@ -1,16 +1,19 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from pytest_mock import MockerFixture -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import FlowNotFoundError from routers.openml.flows import flow_exists from tests.conftest import Flow +if TYPE_CHECKING: + import httpx + from pytest_mock import MockerFixture + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None: response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") diff --git a/tests/routers/openml/flows_get_test.py b/tests/routers/openml/flows_get_test.py index 17bbfcc..862bb70 100644 --- a/tests/routers/openml/flows_get_test.py +++ b/tests/routers/openml/flows_get_test.py @@ -1,9 +1,8 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any import deepdiff.diff -import httpx import pytest from core.conversions import ( @@ -11,6 +10,9 @@ nested_str_to_num, ) +if TYPE_CHECKING: + import httpx + async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/flows/1") diff --git a/tests/routers/openml/qualities_list_test.py b/tests/routers/openml/qualities_list_test.py index 8eb51a5..84cb94b 100644 --- a/tests/routers/openml/qualities_list_test.py +++ b/tests/routers/openml/qualities_list_test.py @@ -1,10 +1,13 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection async def _remove_quality_from_database(quality_name: str, expdb_test: AsyncConnection) -> None: diff --git a/tests/routers/openml/runs_trace_test.py b/tests/routers/openml/runs_trace_test.py index 11fd10a..a9e664c 100644 --- a/tests/routers/openml/runs_trace_test.py +++ b/tests/routers/openml/runs_trace_test.py @@ -2,15 +2,17 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any import deepdiff -import httpx import pytest from core.conversions import nested_num_to_str from core.errors import RunNotFoundError, RunTraceNotFoundError +if TYPE_CHECKING: + import httpx + @pytest.mark.parametrize("run_id", [34]) async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/setups_get_test.py b/tests/routers/openml/setups_get_test.py index 6762714..646fa0c 100644 --- a/tests/routers/openml/setups_get_test.py +++ b/tests/routers/openml/setups_get_test.py @@ -1,12 +1,15 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from core.conversions import nested_remove_values, nested_str_to_num +if TYPE_CHECKING: + import httpx + async def test_get_setup_unknown(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/setup/999999") diff --git a/tests/routers/openml/setups_tag_test.py b/tests/routers/openml/setups_tag_test.py index ad9659f..f4bc516 100644 --- a/tests/routers/openml/setups_tag_test.py +++ b/tests/routers/openml/setups_tag_test.py @@ -3,16 +3,19 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import SetupNotFoundError, TagAlreadyExistsError from routers.openml.setups import tag_setup from tests.users import SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_setup_tag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/tag", json={"setup_id": 1, "tag": "test_tag"}) diff --git a/tests/routers/openml/setups_untag_test.py b/tests/routers/openml/setups_untag_test.py index 55d4b28..1491fd1 100644 --- a/tests/routers/openml/setups_untag_test.py +++ b/tests/routers/openml/setups_untag_test.py @@ -3,16 +3,19 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError from routers.openml.setups import untag_setup from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_setup_untag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"}) diff --git a/tests/routers/openml/study_attach_test.py b/tests/routers/openml/study_attach_test.py index 2da1b8f..dcb6870 100644 --- a/tests/routers/openml/study_attach_test.py +++ b/tests/routers/openml/study_attach_test.py @@ -1,14 +1,17 @@ from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import StudyConflictError from schemas.study import StudyType from tests.users import ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def _attach_tasks_to_study( study_id: int, diff --git a/tests/routers/openml/study_get_test.py b/tests/routers/openml/study_get_test.py index 1ef2cff..0762633 100644 --- a/tests/routers/openml/study_get_test.py +++ b/tests/routers/openml/study_get_test.py @@ -1,11 +1,14 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx from core.conversions import nested_num_to_str, nested_remove_values +if TYPE_CHECKING: + import httpx + async def test_get_task_study_by_id(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/studies/1") diff --git a/tests/routers/openml/study_post_test.py b/tests/routers/openml/study_post_test.py index 0cb00fd..3c3c5e9 100644 --- a/tests/routers/openml/study_post_test.py +++ b/tests/routers/openml/study_post_test.py @@ -1,11 +1,14 @@ from datetime import UTC, datetime from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from tests.users import ApiKey +if TYPE_CHECKING: + import httpx + @pytest.mark.mut async def test_create_task_study(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/task_get_test.py b/tests/routers/openml/task_get_test.py index 955a7b8..7225061 100644 --- a/tests/routers/openml/task_get_test.py +++ b/tests/routers/openml/task_get_test.py @@ -1,8 +1,8 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx import pytest from core.conversions import ( @@ -11,6 +11,9 @@ nested_remove_values, ) +if TYPE_CHECKING: + import httpx + async def test_get_task(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/tasks/59") diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py index 45404d1..7a3f562 100644 --- a/tests/routers/openml/task_list_test.py +++ b/tests/routers/openml/task_list_test.py @@ -1,17 +1,19 @@ import asyncio from http import HTTPStatus -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import deepdiff -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.conversions import nested_remove_single_element_list from core.errors import NoResultsError from routers.dependencies import LIMIT_MAX, Pagination from routers.openml.tasks import TaskStatusFilter, list_tasks +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: """Default call returns active tasks with correct shape.""" diff --git a/tests/routers/openml/task_type_get_test.py b/tests/routers/openml/task_type_get_test.py index 61bd0c9..b186da1 100644 --- a/tests/routers/openml/task_type_get_test.py +++ b/tests/routers/openml/task_type_get_test.py @@ -1,12 +1,15 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff.diff -import httpx import pytest from core.errors import TaskTypeNotFoundError +if TYPE_CHECKING: + import httpx + @pytest.mark.parametrize( "ttype_id", diff --git a/tests/routers/openml/task_type_list_test.py b/tests/routers/openml/task_type_list_test.py index 871def3..ee61e1b 100644 --- a/tests/routers/openml/task_type_list_test.py +++ b/tests/routers/openml/task_type_list_test.py @@ -1,6 +1,8 @@ import asyncio +from typing import TYPE_CHECKING -import httpx +if TYPE_CHECKING: + import httpx async def test_list_task_type(py_api: httpx.AsyncClient, php_api: httpx.AsyncClient) -> None: From b436b7b8f34c1654d6c69e1bebc5ce82c2f3cb66 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:15:18 +0200 Subject: [PATCH 12/16] Fix bug when deleting last tag of a dataset --- src/routers/openml/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index ba3e129..de0e52d 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -99,7 +99,7 @@ async def untag_dataset_like_php( ) -> dict[Literal["data_untag"], UntagInfo]: await untag_dataset(data_id, tag, user, expdb_db) tags = await database.datasets.get_tags_for(id_=data_id, connection=expdb_db) - return_tags = tags if len(tags) > 1 else tags[0] + return_tags = tags[0] if len(tags) == 1 else tags return {"data_untag": {"id": str(data_id), "tag": return_tags}} From 00d39cf4bd5d2f7b161aabd70a83c617f6331c70 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:21:13 +0200 Subject: [PATCH 13/16] Undo some type checking imports because they are used by collection --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 850a434..0f9b49f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ import _pytest.mark import httpx import pytest +from _pytest.config import Config # noqa: TC002 used during collection by Pytest +from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest from asgi_lifespan import LifespanManager from sqlalchemy import text @@ -16,8 +18,6 @@ from tests.users import OWNER_USER if TYPE_CHECKING: - from _pytest.config import Config - from _pytest.nodes import Item from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine From f09ba290b487d515b75ef2355346ea6af7600c00 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:28:54 +0200 Subject: [PATCH 14/16] Update behavior if there is no tag remaining --- src/routers/openml/datasets.py | 16 ++++++++++------ tests/routers/openml/dataset_untag_test.py | 12 +++++++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index de0e52d..12b796f 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import StrEnum from http import HTTPStatus -from typing import TYPE_CHECKING, Annotated, Any, Literal, NamedTuple, TypedDict +from typing import TYPE_CHECKING, Annotated, Any, Literal, NamedTuple, NotRequired, TypedDict from fastapi import APIRouter, Body, Depends, Query from loguru import logger @@ -85,9 +85,9 @@ async def tag_dataset( } -class UntagInfo(TypedDict): +class TagInfo(TypedDict): id: str - tag: SystemString64 | list[SystemString64] + tag: NotRequired[SystemString64 | list[SystemString64]] @router.post(path="/untag", deprecated=True) @@ -96,11 +96,15 @@ async def untag_dataset_like_php( tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], -) -> dict[Literal["data_untag"], UntagInfo]: +) -> dict[Literal["data_untag"], TagInfo]: await untag_dataset(data_id, tag, user, expdb_db) tags = await database.datasets.get_tags_for(id_=data_id, connection=expdb_db) - return_tags = tags[0] if len(tags) == 1 else tags - return {"data_untag": {"id": str(data_id), "tag": return_tags}} + tag_info: TagInfo = {"id": str(data_id)} + if len(tags) == 1: + tag_info["tag"] = tags[0] + elif tags: + tag_info["tag"] = tags + return {"data_untag": tag_info} @router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT) diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py index 4d26d54..cfaaf24 100644 --- a/tests/routers/openml/dataset_untag_test.py +++ b/tests/routers/openml/dataset_untag_test.py @@ -114,20 +114,26 @@ async def test_dataset_untag_dataset_does_not_exist(expdb_test: AsyncConnection) @pytest.mark.mut +@pytest.mark.parametrize("existing_tags", [[], ["bar"], ["bar", "bazz"]]) async def test_dataset_untag_success_is_identical( + existing_tags: list[str], py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, temporary_tags: Callable[..., AbstractAsyncContextManager[None]], ) -> None: - dataset_id = 1 + dataset_id = 101 # The first dataset without a pre-existing tag tag = "foo" - async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id, persist=True): + async with temporary_tags( + table="dataset_tag", tags=[tag, *existing_tags], identifier=dataset_id, persist=True + ): php_response = await php_api.post( f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} ) - async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id): + async with temporary_tags( + table="dataset_tag", tags=[tag, *existing_tags], identifier=dataset_id + ): py_response = await py_api.post( f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} ) From 8c90e606928432b05435f3ea7f6fc7dc7c488ac1 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:45:10 +0200 Subject: [PATCH 15/16] Fix erroneous tests --- tests/routers/openml/datasets_get_test.py | 6 +++--- tests/routers/openml/datasets_list_datasets_test.py | 2 +- tests/routers/openml/setups_tag_test.py | 8 +++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/routers/openml/datasets_get_test.py b/tests/routers/openml/datasets_get_test.py index 8772d05..3ef034e 100644 --- a/tests/routers/openml/datasets_get_test.py +++ b/tests/routers/openml/datasets_get_test.py @@ -249,14 +249,14 @@ async def test_dataset_not_found_is_identical( dataset_id = 9_999_999 py_response, php_response = await asyncio.gather( py_api.get(f"/datasets/{dataset_id}"), - php_api.get(f"/datasets/{dataset_id}"), + php_api.get(f"/data/{dataset_id}"), ) assert py_response.status_code == HTTPStatus.NOT_FOUND assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert py_response.json()["detail"] == f"Dataset {dataset_id} not found." - assert php_response.json()["error"]["message"] == "Dataset not found." + assert py_response.json()["detail"] == f"No dataset with id {dataset_id} found." + assert php_response.json()["error"]["message"] == "Unknown dataset" async def test_private_dataset_no_user_no_access( diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index dc0523b..868e9ee 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -2,6 +2,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any +import httpx # noqa: TC002 is used in a function signature inspected at runtime import hypothesis import pytest from hypothesis import given @@ -15,7 +16,6 @@ from tests.users import ADMIN_USER, DATASET_130_OWNER, SOME_USER, ApiKey if TYPE_CHECKING: - import httpx from sqlalchemy.ext.asyncio import AsyncConnection diff --git a/tests/routers/openml/setups_tag_test.py b/tests/routers/openml/setups_tag_test.py index f4bc516..c674ca2 100644 --- a/tests/routers/openml/setups_tag_test.py +++ b/tests/routers/openml/setups_tag_test.py @@ -114,7 +114,9 @@ async def test_setup_tag_response_is_identical_when_tag_doesnt_exist( # noqa: P setup_id = 1 tag = "totally_new_tag_for_migration_testing" - async with temporary_tags(tags=other_tags, setup_id=setup_id, persist=True): + async with temporary_tags( + table="setup_tag", tags=other_tags, identifier=setup_id, persist=True + ): php_response = await php_api.post( "/setup/tag", data={"api_key": api_key, "tag": tag, "setup_id": setup_id}, @@ -126,7 +128,7 @@ async def test_setup_tag_response_is_identical_when_tag_doesnt_exist( # noqa: P ) await expdb_test.commit() - async with temporary_tags(tags=other_tags, setup_id=setup_id): + async with temporary_tags(table="setup_tag", tags=other_tags, identifier=setup_id): py_response = await py_api.post( f"/setup/tag?api_key={api_key}", json={"setup_id": setup_id, "tag": tag}, @@ -185,7 +187,7 @@ async def test_setup_tag_response_is_identical_tag_already_exists( tag = "totally_new_tag_for_migration_testing" api_key = ApiKey.SOME_USER - async with temporary_tags(tags=[tag], setup_id=setup_id, persist=True): + async with temporary_tags(table="setup_tag", tags=[tag], identifier=setup_id, persist=True): # Both APIs can be tested in parallel since the tag is already persisted php_response, py_response = await asyncio.gather( php_api.post( From be5bc7406ebe037127589d2917dfe5aa36b9cb22 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 4 May 2026 15:49:36 +0200 Subject: [PATCH 16/16] remove now unused table parameter --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0f9b49f..01ac8c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -187,7 +187,6 @@ async def _temporary_tags( ( f"INSERT INTO {table}(`id`,`tag`,`uploader`) VALUES (:identifier, :tag, :user_id);", # noqa: S608 # No user provided values { - "table": table, "identifier": identifier, "tag": tag, "user_id": OWNER_USER.user_id,