Skip to content
32 changes: 32 additions & 0 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,24 @@ class TagAlreadyExistsError(ProblemDetailError):
_default_code = 473


class TagNotFoundError(ProblemDetailError):
"""Raised when trying to remove or retrieve a tag that does not exist."""

uri = "https://openml.org/problems/tag-not-found"
title = "Tag Not Found"
_default_status_code = HTTPStatus.NOT_FOUND
_default_code = 475


class TagNotOwnedError(ProblemDetailError):
"""Raised when trying to remove a tag that was created by someone else."""

uri = "https://openml.org/problems/tag-not-owned"
title = "Tag Not Owned"
_default_status_code = HTTPStatus.FORBIDDEN
_default_code = 476


# =============================================================================
# Search/List Errors
# =============================================================================
Expand Down Expand Up @@ -329,6 +347,20 @@ class FlowNotFoundError(ProblemDetailError):
_default_status_code = HTTPStatus.NOT_FOUND


# =============================================================================
# Setup Errors
# =============================================================================


class SetupNotFoundError(ProblemDetailError):
"""Raised when a setup cannot be found."""

uri = "https://openml.org/problems/setup-not-found"
title = "Setup Not Found"
_default_status_code = HTTPStatus.NOT_FOUND
_default_code = 472


# =============================================================================
# Service Errors
# =============================================================================
Expand Down
48 changes: 48 additions & 0 deletions src/database/setups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""All database operations that directly operate on setups."""

from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
"""Get the setup with id `setup_id` from the database."""
row = await connection.execute(
text(
"""
SELECT *
FROM algorithm_setup
WHERE sid = :setup_id
""",
),
parameters={"setup_id": setup_id},
)
return row.first()


async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
"""Get all tags for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
"""
SELECT *
FROM setup_tag
WHERE id = :setup_id
""",
),
parameters={"setup_id": setup_id},
)
return list(rows.all())


async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
"""Remove tag `tag` from setup with id `setup_id`."""
await connection.execute(
text(
"""
DELETE FROM setup_tag
WHERE id = :setup_id AND tag = :tag
""",
),
parameters={"setup_id": setup_id, "tag": tag},
)
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from routers.openml.evaluations import router as evaluationmeasures_router
from routers.openml.flows import router as flows_router
from routers.openml.qualities import router as qualities_router
from routers.openml.setups import router as setup_router
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
from routers.openml.tasktype import router as ttype_router
Expand Down Expand Up @@ -68,6 +69,7 @@ def create_api() -> FastAPI:
app.include_router(task_router)
app.include_router(flows_router)
app.include_router(study_router)
app.include_router(setup_router)
return app


Expand Down
10 changes: 10 additions & 0 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import AuthenticationFailedError
from database.setup import expdb_database, user_database
from database.users import APIKey, User

Expand All @@ -28,6 +29,15 @@ async def fetch_user(
return await User.fetch(api_key, user_data) if api_key and user_data else None


def fetch_user_or_raise(
user: Annotated[User | None, Depends(fetch_user)] = None,
) -> User:
if user is None:
msg = "Authentication failed"
raise AuthenticationFailedError(msg)
return user


class Pagination(BaseModel):
offset: int = 0
limit: int = 100
15 changes: 8 additions & 7 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import database.qualities
from core.access import _user_has_access
from core.errors import (
AuthenticationFailedError,
AuthenticationRequiredError,
DatasetAdminOnlyError,
DatasetNoAccessError,
Expand All @@ -33,7 +32,13 @@
_format_parquet_url,
)
from database.users import User, UserGroup
from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
from routers.dependencies import (
Pagination,
expdb_connection,
fetch_user,
fetch_user_or_raise,
userdb_connection,
)
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType

Expand All @@ -46,7 +51,7 @@
async def tag_dataset(
data_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User | None, Depends(fetch_user)] = None,
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> dict[str, dict[str, Any]]:
assert expdb_db is not None # noqa: S101
Expand All @@ -55,10 +60,6 @@ async def tag_dataset(
msg = f"Dataset {data_id} already tagged with {tag!r}."
raise TagAlreadyExistsError(msg)

if user is None:
msg = "Authentication failed."
raise AuthenticationFailedError(msg)

await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
return {
"data_tag": {"id": str(data_id), "tag": [*tags, tag]},
Expand Down
44 changes: 44 additions & 0 deletions src/routers/openml/setups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""All endpoints that relate to setups."""

from typing import Annotated

from fastapi import APIRouter, Body, Depends
from sqlalchemy.ext.asyncio import AsyncConnection

import database.setups
from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError
from database.users import User, UserGroup
from routers.dependencies import expdb_connection, fetch_user_or_raise
from routers.types import SystemString64

router = APIRouter(prefix="/setup", tags=["setup"])


@router.post(path="/untag")
async def untag_setup(
setup_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
"""Remove tag `tag` from setup with id `setup_id`."""
if not await database.setups.get(setup_id, expdb_db):
msg = f"Setup {setup_id} not found."
raise SetupNotFoundError(msg)

setup_tags = await database.setups.get_tags(setup_id, expdb_db)
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)

if not matched_tag_row:
msg = f"Setup {setup_id} does not have tag {tag!r}."
raise TagNotFoundError(msg)

if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups():
msg = (
f"You may not remove tag {tag!r} of setup {setup_id} because it was not created by you."
)
raise TagNotOwnedError(msg)

await database.setups.untag(setup_id, matched_tag_row.tag, expdb_db)
remaining_tags = [t.tag.casefold() for t in setup_tags if t != matched_tag_row]
return {"setup_untag": {"id": str(setup_id), "tag": remaining_tags}}
2 changes: 1 addition & 1 deletion tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def test_dataset_tag_invalid_tag_is_rejected(
py_api: httpx.AsyncClient,
) -> None:
new = await py_api.post(
f"/datasets/tag?api_key{ApiKey.ADMIN}",
f"/datasets/tag?api_key={ApiKey.ADMIN}",
json={"data_id": 1, "tag": tag},
)

Expand Down
149 changes: 149 additions & 0 deletions tests/routers/openml/migration/setups_migration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import contextlib
import re
from collections.abc import AsyncGenerator, Iterable
from http import HTTPStatus

import httpx
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection

from tests.users import OWNER_USER, ApiKey


@pytest.mark.parametrize(
"api_key",
[ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
ids=["Administrator", "non-owner", "tag owner"],
)
@pytest.mark.parametrize(
"other_tags",
[[], ["some_other_tag"], ["foo_some_other_tag", "bar_some_other_tag"]],
ids=["none", "one tag", "two tags"],
)
async def test_setup_untag_response_is_identical_when_tag_exists(
api_key: str,
other_tags: list[str],
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
expdb_test: AsyncConnection,
) -> None:
setup_id = 1
tag = "totally_new_tag_for_migration_testing"

@contextlib.asynccontextmanager
async def temporary_tags(
tags: Iterable[str], setup_id: int, *, persist: bool = False
) -> AsyncGenerator[None]:
for tag in tags:
await expdb_test.execute(
text(
"INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);" # noqa: E501
),
parameters={"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id},
)
if persist:
await expdb_test.commit()
yield
for tag in tags:
await expdb_test.execute(
text("DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag"),
parameters={"setup_id": setup_id, "tag": tag},
)
if persist:
await expdb_test.commit()

all_tags = [tag, *other_tags]
async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True):
original = await php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

# expdb_test transaction shared with Python API,
# no commit needed and rolled back at the end of the test
async with temporary_tags(tags=all_tags, setup_id=setup_id):
new = await py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
)

if new.status_code == HTTPStatus.OK:
assert original.status_code == new.status_code
original_untag = original.json()["setup_untag"]
new_untag = new.json()["setup_untag"]
assert original_untag["id"] == new_untag["id"]
if tags := original_untag.get("tag"):
if isinstance(tags, str):
assert tags == new_untag["tag"][0]
else:
assert tags == new_untag["tag"]
else:
assert new_untag["tag"] == []
return

code, message = original.json()["error"].values()
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
assert new.status_code == HTTPStatus.FORBIDDEN
assert code == new.json()["code"]
assert message == "Tag is not owned by you"
assert re.match(
r"You may not remove tag \S+ of setup \d+ because it was not created by you.",
new.json()["detail"],
)


async def test_setup_untag_response_is_identical_setup_doesnt_exist(
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
) -> None:
setup_id = 999999
tag = "totally_new_tag_for_migration_testing"
api_key = ApiKey.SOME_USER

original = await php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

new = await py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
)

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
assert new.status_code == HTTPStatus.NOT_FOUND
assert original.json()["error"]["message"] == "Entity not found."
assert original.json()["error"]["code"] == new.json()["code"]
assert re.match(
r"Setup \d+ not found.",
new.json()["detail"],
)


async def test_setup_untag_response_is_identical_tag_doesnt_exist(
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
) -> None:
setup_id = 1
tag = "totally_new_tag_for_migration_testing"
api_key = ApiKey.SOME_USER

original = await php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

new = await py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
)

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
assert new.status_code == HTTPStatus.NOT_FOUND
assert original.json()["error"]["code"] == new.json()["code"]
assert original.json()["error"]["message"] == "Tag not found."
assert re.match(
r"Setup \d+ does not have tag '\S+'.",
new.json()["detail"],
)
Loading
Loading