diff --git a/src/core/errors.py b/src/core/errors.py
index 5f874894..3f53364a 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -219,6 +219,24 @@ class TagAlreadyExistsError(ProblemDetailError):
     _default_code = 473
 
 
+class TagNotFoundError(ProblemDetailError):
+    """Raised when trying to remove or retrieve a tag that does not exist."""
+
+    uri = "https://openml.org/problems/tag-not-found"
+    title = "Tag Not Found"
+    _default_status_code = HTTPStatus.NOT_FOUND
+    _default_code = 475
+
+
+class TagNotOwnedError(ProblemDetailError):
+    """Raised when trying to remove a tag that was created by someone else."""
+
+    uri = "https://openml.org/problems/tag-not-owned"
+    title = "Tag Not Owned"
+    _default_status_code = HTTPStatus.FORBIDDEN
+    _default_code = 476
+
+
 # =============================================================================
 # Search/List Errors
 # =============================================================================
@@ -329,6 +347,20 @@ class FlowNotFoundError(ProblemDetailError):
     _default_status_code = HTTPStatus.NOT_FOUND
 
 
+# =============================================================================
+# Setup Errors
+# =============================================================================
+
+
+class SetupNotFoundError(ProblemDetailError):
+    """Raised when a setup cannot be found."""
+
+    uri = "https://openml.org/problems/setup-not-found"
+    title = "Setup Not Found"
+    _default_status_code = HTTPStatus.NOT_FOUND
+    _default_code = 472
+
+
 # =============================================================================
 # Service Errors
 # =============================================================================
diff --git a/src/database/setups.py b/src/database/setups.py
new file mode 100644
index 00000000..866a7de3
--- /dev/null
+++ b/src/database/setups.py
@@ -0,0 +1,48 @@
+"""All database operations that directly operate on setups."""
+
+from sqlalchemy import text
+from sqlalchemy.engine import Row
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+
+async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
+    """Get the setup with id `setup_id` from the database."""
+    row = await connection.execute(
+        text(
+            """
+    SELECT *
+    FROM algorithm_setup
+    WHERE sid = :setup_id
+    """,
+        ),
+        parameters={"setup_id": setup_id},
+    )
+    return row.first()
+
+
+async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
+    """Get all tags for setup with `setup_id` from the database."""
+    rows = await connection.execute(
+        text(
+            """
+    SELECT *
+    FROM setup_tag
+    WHERE id = :setup_id
+    """,
+        ),
+        parameters={"setup_id": setup_id},
+    )
+    return list(rows.all())
+
+
+async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
+    """Remove tag `tag` from setup with id `setup_id`."""
+    await connection.execute(
+        text(
+            """
+    DELETE FROM setup_tag
+    WHERE id = :setup_id AND tag = :tag
+    """,
+        ),
+        parameters={"setup_id": setup_id, "tag": tag},
+    )
diff --git a/src/main.py b/src/main.py
index e2a16319..76a52ad3 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,6 +15,7 @@
 from routers.openml.evaluations import router as evaluationmeasures_router
 from routers.openml.flows import router as flows_router
 from routers.openml.qualities import router as qualities_router
+from routers.openml.setups import router as setup_router
 from routers.openml.study import router as study_router
 from routers.openml.tasks import router as task_router
 from routers.openml.tasktype import router as ttype_router
@@ -68,6 +69,7 @@ def create_api() -> FastAPI:
     app.include_router(task_router)
     app.include_router(flows_router)
     app.include_router(study_router)
+    app.include_router(setup_router)
 
     return app
 
diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py
index f73ac995..590ae36b 100644
--- a/src/routers/dependencies.py
+++ b/src/routers/dependencies.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel
 from sqlalchemy.ext.asyncio import AsyncConnection
 
+from core.errors import AuthenticationFailedError
 from database.setup import expdb_database, user_database
 from database.users import APIKey, User
 
@@ -28,6 +29,16 @@ async def fetch_user(
     return await User.fetch(api_key, user_data) if api_key and user_data else None
 
 
+def fetch_user_or_raise(
+    user: Annotated[User | None, Depends(fetch_user)] = None,
+) -> User:
+    """Return the user resolved by `fetch_user`, or raise if there is none."""
+    if user is None:
+        msg = "Authentication failed"
+        raise AuthenticationFailedError(msg)
+    return user
+
+
 class Pagination(BaseModel):
     offset: int = 0
     limit: int = 100
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index f2b394a0..d86ed848 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -12,7 +12,6 @@
 import database.qualities
 from core.access import _user_has_access
 from core.errors import (
-    AuthenticationFailedError,
     AuthenticationRequiredError,
     DatasetAdminOnlyError,
     DatasetNoAccessError,
@@ -33,7 +32,13 @@
     _format_parquet_url,
 )
 from database.users import User, UserGroup
-from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
+from routers.dependencies import (
+    Pagination,
+    expdb_connection,
+    fetch_user,
+    fetch_user_or_raise,
+    userdb_connection,
+)
 from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
 from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
 
@@ -46,7 +51,7 @@
 async def tag_dataset(
     data_id: Annotated[int, Body()],
     tag: Annotated[str, SystemString64],
-    user: Annotated[User | None, Depends(fetch_user)] = None,
+    user: Annotated[User, Depends(fetch_user_or_raise)],
     expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
 ) -> dict[str, dict[str, Any]]:
     assert expdb_db is not None  # noqa: S101
@@ -55,10 +60,6 @@ async def tag_dataset(
         msg = f"Dataset {data_id} already tagged with {tag!r}."
         raise TagAlreadyExistsError(msg)
 
-    if user is None:
-        msg = "Authentication failed."
-        raise AuthenticationFailedError(msg)
-
     await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
     return {
         "data_tag": {"id": str(data_id), "tag": [*tags, tag]},
diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py
new file mode 100644
index 00000000..c7823f8e
--- /dev/null
+++ b/src/routers/openml/setups.py
@@ -0,0 +1,48 @@
+"""All endpoints that relate to setups."""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body, Depends
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+import database.setups
+from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError
+from database.users import User, UserGroup
+from routers.dependencies import expdb_connection, fetch_user_or_raise
+from routers.types import SystemString64
+
+router = APIRouter(prefix="/setup", tags=["setup"])
+
+
+@router.post(path="/untag")
+async def untag_setup(
+    setup_id: Annotated[int, Body()],
+    tag: Annotated[str, SystemString64],
+    user: Annotated[User, Depends(fetch_user_or_raise)],
+    expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
+) -> dict[str, dict[str, str | list[str]]]:
+    """Remove tag `tag` from setup with id `setup_id`.
+
+    Tags are matched case-insensitively. Only the user that created the tag or an
+    administrator may remove it. The response lists the remaining tags, casefolded.
+    """
+    if await database.setups.get(setup_id, expdb_db) is None:
+        msg = f"Setup {setup_id} not found."
+        raise SetupNotFoundError(msg)
+
+    setup_tags = await database.setups.get_tags(setup_id, expdb_db)
+    matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)
+
+    if matched_tag_row is None:
+        msg = f"Setup {setup_id} does not have tag {tag!r}."
+        raise TagNotFoundError(msg)
+
+    if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups():
+        msg = (
+            f"You may not remove tag {tag!r} of setup {setup_id} because it was not created by you."
+        )
+        raise TagNotOwnedError(msg)
+
+    await database.setups.untag(setup_id, matched_tag_row.tag, expdb_db)
+    remaining_tags = [t.tag.casefold() for t in setup_tags if t != matched_tag_row]
+    return {"setup_untag": {"id": str(setup_id), "tag": remaining_tags}}
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 646ac0c3..25042c89 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -83,7 +83,7 @@ async def test_dataset_tag_invalid_tag_is_rejected(
     py_api: httpx.AsyncClient,
 ) -> None:
     new = await py_api.post(
-        f"/datasets/tag?api_key{ApiKey.ADMIN}",
+        f"/datasets/tag?api_key={ApiKey.ADMIN}",
         json={"data_id": 1, "tag": tag},
     )
 
diff --git a/tests/routers/openml/migration/setups_migration_test.py b/tests/routers/openml/migration/setups_migration_test.py
new file mode 100644
index 00000000..361dfd78
--- /dev/null
+++ b/tests/routers/openml/migration/setups_migration_test.py
@@ -0,0 +1,154 @@
+import contextlib
+import re
+from collections.abc import AsyncGenerator, Iterable
+from http import HTTPStatus
+
+import httpx
+import pytest
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+from tests.users import OWNER_USER, ApiKey
+
+
+@pytest.mark.parametrize(
+    "api_key",
+    [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
+    ids=["Administrator", "non-owner", "tag owner"],
+)
+@pytest.mark.parametrize(
+    "other_tags",
+    [[], ["some_other_tag"], ["foo_some_other_tag", "bar_some_other_tag"]],
+    ids=["none", "one tag", "two tags"],
+)
+async def test_setup_untag_response_is_identical_when_tag_exists(
+    api_key: str,
+    other_tags: list[str],
+    py_api: httpx.AsyncClient,
+    php_api: httpx.AsyncClient,
+    expdb_test: AsyncConnection,
+) -> None:
+    setup_id = 1
+    tag = "totally_new_tag_for_migration_testing"
+
+    @contextlib.asynccontextmanager
+    async def temporary_tags(
+        tags: Iterable[str], setup_id: int, *, persist: bool = False
+    ) -> AsyncGenerator[None]:
+        """Insert `tags` for `setup_id`; delete them on exit, even if the body raises."""
+        # Materialize once: the names are needed for both the insert and the cleanup.
+        tag_names = tuple(tags)
+        for name in tag_names:
+            await expdb_test.execute(
+                text(
+                    "INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);"  # noqa: E501
+                ),
+                parameters={"setup_id": setup_id, "tag": name, "user_id": OWNER_USER.user_id},
+            )
+        if persist:
+            await expdb_test.commit()
+        try:
+            yield
+        finally:
+            for name in tag_names:
+                await expdb_test.execute(
+                    text("DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag"),
+                    parameters={"setup_id": setup_id, "tag": name},
+                )
+            if persist:
+                await expdb_test.commit()
+
+    all_tags = [tag, *other_tags]
+    async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True):
+        original = await php_api.post(
+            "/setup/untag",
+            data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
+        )
+
+    # expdb_test transaction shared with Python API,
+    # no commit needed and rolled back at the end of the test
+    async with temporary_tags(tags=all_tags, setup_id=setup_id):
+        new = await py_api.post(
+            f"/setup/untag?api_key={api_key}",
+            json={"setup_id": setup_id, "tag": tag},
+        )
+
+    if new.status_code == HTTPStatus.OK:
+        assert original.status_code == new.status_code
+        original_untag = original.json()["setup_untag"]
+        new_untag = new.json()["setup_untag"]
+        assert original_untag["id"] == new_untag["id"]
+        if tags := original_untag.get("tag"):
+            if isinstance(tags, str):
+                assert tags == new_untag["tag"][0]
+            else:
+                assert tags == new_untag["tag"]
+        else:
+            assert new_untag["tag"] == []
+        return
+
+    code, message = original.json()["error"].values()
+    assert original.status_code == HTTPStatus.PRECONDITION_FAILED
+    assert new.status_code == HTTPStatus.FORBIDDEN
+    assert code == new.json()["code"]
+    assert message == "Tag is not owned by you"
+    assert re.match(
+        r"You may not remove tag \S+ of setup \d+ because it was not created by you.",
+        new.json()["detail"],
+    )
+
+
+async def test_setup_untag_response_is_identical_setup_doesnt_exist(
+    py_api: httpx.AsyncClient,
+    php_api: httpx.AsyncClient,
+) -> None:
+    setup_id = 999999
+    tag = "totally_new_tag_for_migration_testing"
+    api_key = ApiKey.SOME_USER
+
+    original = await php_api.post(
+        "/setup/untag",
+        data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
+    )
+
+    new = await py_api.post(
+        f"/setup/untag?api_key={api_key}",
+        json={"setup_id": setup_id, "tag": tag},
+    )
+
+    assert original.status_code == HTTPStatus.PRECONDITION_FAILED
+    assert new.status_code == HTTPStatus.NOT_FOUND
+    assert original.json()["error"]["message"] == "Entity not found."
+    assert original.json()["error"]["code"] == new.json()["code"]
+    assert re.match(
+        r"Setup \d+ not found.",
+        new.json()["detail"],
+    )
+
+
+async def test_setup_untag_response_is_identical_tag_doesnt_exist(
+    py_api: httpx.AsyncClient,
+    php_api: httpx.AsyncClient,
+) -> None:
+    setup_id = 1
+    tag = "totally_new_tag_for_migration_testing"
+    api_key = ApiKey.SOME_USER
+
+    original = await php_api.post(
+        "/setup/untag",
+        data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
+    )
+
+    new = await py_api.post(
+        f"/setup/untag?api_key={api_key}",
+        json={"setup_id": setup_id, "tag": tag},
+    )
+
+    assert original.status_code == HTTPStatus.PRECONDITION_FAILED
+    assert new.status_code == HTTPStatus.NOT_FOUND
+    assert original.json()["error"]["code"] == new.json()["code"]
+    assert original.json()["error"]["message"] == "Tag not found."
+    assert re.match(
+        r"Setup \d+ does not have tag '\S+'.",
+        new.json()["detail"],
+    )
diff --git a/tests/routers/openml/setups_test.py b/tests/routers/openml/setups_test.py
new file mode 100644
index 00000000..5ea8b515
--- /dev/null
+++ b/tests/routers/openml/setups_test.py
@@ -0,0 +1,85 @@
+import re
+from http import HTTPStatus
+
+import httpx
+import pytest
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncConnection
+
+from tests.users import ApiKey
+
+
+async def test_setup_untag_missing_auth(py_api: httpx.AsyncClient) -> None:
+    response = await py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"})
+    assert response.status_code == HTTPStatus.UNAUTHORIZED
+    assert response.json()["code"] == "103"
+    assert response.json()["detail"] == "Authentication failed"
+
+
+async def test_setup_untag_unknown_setup(py_api: httpx.AsyncClient) -> None:
+    response = await py_api.post(
+        f"/setup/untag?api_key={ApiKey.SOME_USER}",
+        json={"setup_id": 999999, "tag": "test_tag"},
+    )
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert re.match(
+        r"Setup \d+ not found.",
+        response.json()["detail"],
+    )
+
+
+async def test_setup_untag_tag_not_found(py_api: httpx.AsyncClient) -> None:
+    response = await py_api.post(
+        f"/setup/untag?api_key={ApiKey.SOME_USER}",
+        json={"setup_id": 1, "tag": "non_existent_tag_12345"},
+    )
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert re.match(
+        r"Setup \d+ does not have tag '\S+'.",
+        response.json()["detail"],
+    )
+
+
+@pytest.mark.mut
+async def test_setup_untag_not_owned_by_you(
+    py_api: httpx.AsyncClient, expdb_test: AsyncConnection
+) -> None:
+    await expdb_test.execute(
+        text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, 'test_unit_tag_123', 2);")
+    )
+    response = await py_api.post(
+        f"/setup/untag?api_key={ApiKey.OWNER_USER}",
+        json={"setup_id": 1, "tag": "test_unit_tag_123"},
+    )
+    assert response.status_code == HTTPStatus.FORBIDDEN
+    assert re.match(
+        r"You may not remove tag '\S+' of setup \d+ because it was not created by you.",
+        response.json()["detail"],
+    )
+
+
+@pytest.mark.mut
+@pytest.mark.parametrize(
+    "api_key",
+    [ApiKey.SOME_USER, ApiKey.ADMIN],
+    ids=["Owner", "Administrator"],
+)
+async def test_setup_untag_success(
+    api_key: str, py_api: httpx.AsyncClient, expdb_test: AsyncConnection
+) -> None:
+    await expdb_test.execute(
+        text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, 'test_success_tag', 2)")
+    )
+
+    response = await py_api.post(
+        f"/setup/untag?api_key={api_key}",
+        json={"setup_id": 1, "tag": "test_success_tag"},
+    )
+
+    assert response.status_code == HTTPStatus.OK
+    assert response.json() == {"setup_untag": {"id": "1", "tag": []}}
+
+    rows = await expdb_test.execute(
+        text("SELECT * FROM setup_tag WHERE id = 1 AND tag = 'test_success_tag'")
+    )
+    assert len(rows.all()) == 0