diff --git a/src/database/datasets.py b/src/database/datasets.py index 4e76dcf9..561869ab 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -68,6 +68,35 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) ) +async def get_tags(id_: int, connection: AsyncConnection) -> list[Row]: + row = await connection.execute( + text( + """ + SELECT * + FROM dataset_tag + WHERE id = :dataset_id + """, + ), + parameters={"dataset_id": id_}, + ) + return list(row.all()) + + +async def untag(id_: int, tag_: str, *, connection: AsyncConnection) -> None: + await connection.execute( + text( + """ + DELETE FROM dataset_tag + WHERE `id` = :dataset_id AND `tag` = :tag + """, + ), + parameters={ + "dataset_id": id_, + "tag": tag_, + }, + ) + + async def get_description( id_: int, connection: AsyncConnection, diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index d86ed848..d2ff7f1a 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -25,6 +25,8 @@ InternalError, NoResultsError, TagAlreadyExistsError, + TagNotFoundError, + TagNotOwnedError, ) from core.formatting import ( _csv_as_list, @@ -66,6 +68,39 @@ async def tag_dataset( } +@router.post( + path="/untag", +) +async def untag_dataset( + data_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, +) -> dict[str, dict[str, Any]]: + assert expdb_db is not None # noqa: S101 + if not await database.datasets.get(data_id, expdb_db): + msg = f"No dataset with id {data_id} found." + raise DatasetNotFoundError(msg) + + dataset_tags = await database.datasets.get_tags(data_id, expdb_db) + matched_tag_row = next((t for t in dataset_tags if t.tag.casefold() == tag.casefold()), None) + if matched_tag_row is None: + msg = f"Dataset {data_id} does not have tag {tag!r}." + raise TagNotFoundError(msg) + + if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups(): + msg = ( + f"You may not remove tag {tag!r} of dataset {data_id} " + "because it was not created by you." + ) + raise TagNotOwnedError(msg) + + await database.datasets.untag(data_id, matched_tag_row.tag, connection=expdb_db) + return { + "data_untag": {"id": str(data_id)}, + } + + class DatasetStatusFilter(StrEnum): ACTIVE = DatasetStatus.ACTIVE DEACTIVATED = DatasetStatus.DEACTIVATED diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 41254863..dbc467a9 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -7,29 +7,38 @@ from core.conversions import _str_to_num from core.errors import FlowNotFoundError from routers.dependencies import expdb_connection -from schemas.flows import Flow, Parameter, Subflow +from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow router = APIRouter(prefix="/flows", tags=["flows"]) -@router.get("/exists/{name}/{external_version}") +@router.post("/exists") async def flow_exists( - name: str, - external_version: str, + body: FlowExistsBody, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["flow_id"], int]: """Check if a Flow with the name and version exists, if so, return the flow id.""" flow = await database.flows.get_by_name( - name=name, - external_version=external_version, + name=body.name, + external_version=body.external_version, expdb=expdb, ) if flow is None: - msg = f"Flow with name {name} and external version {external_version} not found." + msg = f"Flow with name {body.name} and external version {body.external_version} not found." raise FlowNotFoundError(msg) return {"flow_id": flow.id} +@router.get("/exists/{name}/{external_version}", deprecated=True) +async def flow_exists_get( + name: str, + external_version: str, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[Literal["flow_id"], int]: + """Use POST /flows/exists instead.""" + return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb) + + @router.get("/{flow_id}") async def get_flow( flow_id: int, diff --git a/src/schemas/flows.py b/src/schemas/flows.py index a6cd479c..04ef1455 100644 --- a/src/schemas/flows.py +++ b/src/schemas/flows.py @@ -6,6 +6,11 @@ from pydantic import BaseModel, ConfigDict, Field +class FlowExistsBody(BaseModel): + name: str = Field(min_length=1, max_length=1024) + external_version: str = Field(min_length=1, max_length=128) + + class Parameter(BaseModel): name: str default_value: Any diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 25042c89..586428f2 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -89,3 +89,91 @@ async def test_dataset_tag_invalid_tag_is_rejected( assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert new.json()["detail"][0]["loc"] == ["body", "tag"] + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_dataset_untag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/datasets/untag{apikey}", + json={"data_id": 1, "tag": "study_14"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + assert response.headers["content-type"] == "application/problem+json" + error = response.json() + assert error["type"] == AuthenticationFailedError.uri + assert error["code"] == "103" + + +async def test_dataset_untag(py_api: httpx.AsyncClient, expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "temp_dataset_untag" + await py_api.post( + f"/datasets/tag?api_key={ApiKey.SOME_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + + response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.SOME_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"data_untag": {"id": str(dataset_id)}} + assert tag not in await get_tags_for(id_=dataset_id, connection=expdb_test) + + +async def test_dataset_untag_rejects_other_user(py_api: httpx.AsyncClient) -> None: + dataset_id = 1 + tag = "temp_dataset_untag_not_owned" + await py_api.post( + f"/datasets/tag?api_key={ApiKey.SOME_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + + response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.FORBIDDEN + assert response.json()["code"] == "476" + assert "not created by you" in response.json()["detail"] + + cleanup = await py_api.post( + f"/datasets/untag?api_key={ApiKey.SOME_USER}", + json={"data_id": dataset_id, "tag": tag}, + ) + assert cleanup.status_code == HTTPStatus.OK + + +async def test_dataset_untag_fails_if_tag_does_not_exist(py_api: httpx.AsyncClient) -> None: + dataset_id = 1 + tag = "definitely_not_a_dataset_tag" + response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.ADMIN}", + json={"data_id": dataset_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.json()["code"] == "475" + assert "does not have tag" in response.json()["detail"] + + +@pytest.mark.parametrize( + "tag", + ["", "h@", " a", "a" * 65], + ids=["too short", "@", "space", "too long"], +) +async def test_dataset_untag_invalid_tag_is_rejected( + tag: str, + py_api: httpx.AsyncClient, +) -> None: + response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.ADMIN}", + json={"data_id": 1, "tag": tag}, + ) + + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json()["detail"][0]["loc"] == ["body", "tag"] diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index 400ec4c0..b1841d36 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -4,10 +4,12 @@ import httpx import pytest from pytest_mock import MockerFixture +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import FlowNotFoundError from routers.openml.flows import flow_exists +from schemas.flows import FlowExistsBody from tests.conftest import Flow @@ -28,7 +30,7 @@ async def test_flow_exists_calls_db_correctly( "database.flows.get_by_name", new_callable=mocker.AsyncMock, ) - await flow_exists(name, external_version, expdb_test) + await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test) mocked_db.assert_called_once_with( name=name, external_version=external_version, @@ -51,29 +53,42 @@ async def test_flow_exists_processes_found( new_callable=mocker.AsyncMock, return_value=fake_flow, ) - response = await flow_exists("name", "external_version", expdb_test) + response = await flow_exists( + FlowExistsBody(name="name", external_version="external_version"), + expdb_test, + ) assert response == {"flow_id": fake_flow.id} async def test_flow_exists_handles_flow_not_found( mocker: MockerFixture, expdb_test: AsyncConnection ) -> None: - mocker.patch("database.flows.get_by_name", return_value=None) + mocker.patch( + "database.flows.get_by_name", + new_callable=mocker.AsyncMock, + return_value=None, + ) with pytest.raises(FlowNotFoundError) as error: - await flow_exists("foo", "bar", expdb_test) + await flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.uri == FlowNotFoundError.uri async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None: - response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + response = await py_api.post( + "/flows/exists", + json={"name": flow.name, "external_version": flow.external_version}, + ) assert response.status_code == HTTPStatus.OK assert response.json() == {"flow_id": flow.id} async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: name, version = "foo", "bar" - response = await py_api.get(f"/flows/exists/{name}/{version}") + response = await py_api.post( + "/flows/exists", + json={"name": name, "external_version": version}, + ) assert response.status_code == HTTPStatus.NOT_FOUND assert response.headers["content-type"] == "application/problem+json" error = response.json() @@ -82,6 +97,60 @@ async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: assert version in error["detail"] +@pytest.mark.parametrize( + ("name", "external_version"), + [ + ("", "v1"), + ("some-flow", ""), + ], +) +async def test_flow_exists_rejects_empty_fields( + py_api: httpx.AsyncClient, + name: str, + external_version: str, +) -> None: + response = await py_api.post( + "/flows/exists", + json={"name": name, "external_version": external_version}, + ) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +async def test_flow_exists_with_uri_unsafe_chars( + py_api: httpx.AsyncClient, + expdb_test: AsyncConnection, +) -> None: + name = "sklearn.pipeline.Pipeline(steps=[('a','b')])" + external_version = "v1" + await expdb_test.execute( + text( + """ + INSERT INTO implementation(fullname,name,version,external_version,uploadDate) + VALUES (:fullname,:name,2,:external_version,'2024-02-02 02:23:23'); + """, + ), + parameters={ + "fullname": name, + "name": name, + "external_version": external_version, + }, + ) + result = await expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")) + (flow_id,) = result.one() + response = await py_api.post( + "/flows/exists", + json={"name": name, "external_version": external_version}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"flow_id": flow_id} + + +async def test_flow_exists_get_deprecated(flow: Flow, py_api: httpx.AsyncClient) -> None: + response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + assert response.status_code == HTTPStatus.OK + assert response.json() == {"flow_id": flow.id} + + async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/flows/1") assert response.status_code == HTTPStatus.OK diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py index 5ff6fe86..6cf39001 100644 --- a/tests/routers/openml/migration/datasets_migration_test.py +++ b/tests/routers/openml/migration/datasets_migration_test.py @@ -226,6 +226,61 @@ async def test_dataset_tag_response_is_identical( assert original == new +@pytest.mark.parametrize( + "dataset_id", + [1, 2, 3, 101, 131], +) +@pytest.mark.parametrize( + "api_key", + [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], + ids=["Administrator", "regular user", "possible owner"], +) +@pytest.mark.parametrize( + "tag", + ["study_14", "study_15"], +) +async def test_dataset_untag_response_is_identical( + dataset_id: int, + tag: str, + api_key: str, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + original = await php_api.post( + "/data/untag", + data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, + ) + if original.status_code == HTTPStatus.OK: + await php_api.post( + "/data/tag", + data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, + ) + + new = await py_api.post( + f"/datasets/untag?api_key={api_key}", + json={"data_id": dataset_id, "tag": tag}, + ) + + if new.status_code == HTTPStatus.OK: + assert original.status_code == new.status_code, original.json() + assert original.json() == new.json() + return + + code, message = original.json()["error"].values() + if message == "Tag is not owned by you": + assert original.status_code == HTTPStatus.PRECONDITION_FAILED + assert new.status_code == HTTPStatus.FORBIDDEN + assert code == new.json()["code"] + assert "not created by you" in new.json()["detail"] + return + + assert original.status_code == HTTPStatus.PRECONDITION_FAILED + assert new.status_code == HTTPStatus.NOT_FOUND + assert code == new.json()["code"] + assert message == "Tag not found." + assert "does not have tag" in new.json()["detail"] + + @pytest.mark.parametrize( "data_id", list(range(1, 130)),