Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,35 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
)


async def get_tags(id_: int, connection: AsyncConnection) -> list[Row]:
    """Fetch every `dataset_tag` row recorded for the dataset with identifier `id_`.

    All columns of the matching rows are returned; the list is empty when
    no row matches.
    """
    query = text(
        """
        SELECT *
        FROM dataset_tag
        WHERE id = :dataset_id
        """,
    )
    result = await connection.execute(query, parameters={"dataset_id": id_})
    return list(result.all())


async def untag(id_: int, tag_: str, *, connection: AsyncConnection) -> int:
    """Delete tag `tag_` from the dataset with identifier `id_`.

    Returns the number of rows the DELETE removed. A return value of 0 means
    nothing was deleted (for example, the tag disappeared between a caller's
    lookup and this statement), which lets callers distinguish a real removal
    from a no-op instead of assuming success.
    """
    result = await connection.execute(
        text(
            """
            DELETE FROM dataset_tag
            WHERE `id` = :dataset_id AND `tag` = :tag
            """,
        ),
        parameters={
            "dataset_id": id_,
            "tag": tag_,
        },
    )
    return result.rowcount


async def get_description(
id_: int,
connection: AsyncConnection,
Expand Down
35 changes: 35 additions & 0 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
InternalError,
NoResultsError,
TagAlreadyExistsError,
TagNotFoundError,
TagNotOwnedError,
)
from core.formatting import (
_csv_as_list,
Expand Down Expand Up @@ -66,6 +68,39 @@ async def tag_dataset(
}


@router.post(
path="/untag",
)
async def untag_dataset(
data_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> dict[str, dict[str, Any]]:
assert expdb_db is not None # noqa: S101
if not await database.datasets.get(data_id, expdb_db):
msg = f"No dataset with id {data_id} found."
raise DatasetNotFoundError(msg)

dataset_tags = await database.datasets.get_tags(data_id, expdb_db)
matched_tag_row = next((t for t in dataset_tags if t.tag.casefold() == tag.casefold()), None)
if matched_tag_row is None:
msg = f"Dataset {data_id} does not have tag {tag!r}."
raise TagNotFoundError(msg)

if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups():
msg = (
f"You may not remove tag {tag!r} of dataset {data_id} "
"because it was not created by you."
)
raise TagNotOwnedError(msg)

await database.datasets.untag(data_id, matched_tag_row.tag, connection=expdb_db)
Comment on lines +85 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Make the untag operation atomic.

This authorizes against matched_tag_row in one query and then does a separate delete in another. Two concurrent untag requests can both pass the precheck; the loser deletes 0 rows but still returns 200 OK. Please collapse this into a single authorized delete, or have database.datasets.untag(...) return rowcount and fail when nothing was actually removed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/routers/openml/datasets.py` around lines 85 - 98, The current flow
fetches dataset_tags and authorizes based on matched_tag_row, then calls
database.datasets.untag separately, allowing a race where both requests pass the
precheck; make the untag operation atomic by moving authorization into the
delete: modify database.datasets.untag (or add a new method) to accept the
deleting user's id (user.user_id) and perform a single SQL DELETE ... WHERE
id = ? AND tag = ? AND (uploader = ? OR EXISTS(SELECT 1 FROM user_groups
WHERE user_id = ? AND group = 'ADMIN')), returning the rowcount; in the route
replace the precheck+untag sequence with a single call to that atomic method and
raise TagNotFoundError or TagNotOwnedError when rowcount == 0 so the request
fails when nothing was removed.

return {
"data_untag": {"id": str(data_id)},
}


class DatasetStatusFilter(StrEnum):
ACTIVE = DatasetStatus.ACTIVE
DEACTIVATED = DatasetStatus.DEACTIVATED
Expand Down
23 changes: 16 additions & 7 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,38 @@
from core.conversions import _str_to_num
from core.errors import FlowNotFoundError
from routers.dependencies import expdb_connection
from schemas.flows import Flow, Parameter, Subflow
from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.get("/exists/{name}/{external_version}")
@router.post("/exists")
async def flow_exists(
name: str,
external_version: str,
body: FlowExistsBody,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Check if a Flow with the name and version exists, if so, return the flow id."""
flow = await database.flows.get_by_name(
name=name,
external_version=external_version,
name=body.name,
external_version=body.external_version,
expdb=expdb,
)
if flow is None:
msg = f"Flow with name {name} and external version {external_version} not found."
msg = f"Flow with name {body.name} and external version {body.external_version} not found."
raise FlowNotFoundError(msg)
return {"flow_id": flow.id}


@router.get("/exists/{name}/{external_version}", deprecated=True)
async def flow_exists_get(
    name: str,
    external_version: str,
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
    """Use POST /flows/exists instead."""
    # Imported locally so this handler does not depend on the module's
    # top-level imports being extended.
    from fastapi.exceptions import RequestValidationError
    from pydantic import ValidationError

    try:
        body = FlowExistsBody(name=name, external_version=external_version)
    except ValidationError as error:
        # A model built by hand is not covered by FastAPI's request validation,
        # so a bare ValidationError would surface as a 500. Re-raise it as a
        # request validation error (422), locating the fields in the path.
        errors = [
            {**detail, "loc": ("path", *detail["loc"])}
            for detail in error.errors(include_url=False)
        ]
        raise RequestValidationError(errors) from error
    return await flow_exists(body, expdb)
Comment on lines +34 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

In FastAPI with Pydantic v2, if a route handler manually instantiates a Pydantic model and that constructor raises ValidationError, is the response converted to 422 like request validation, or does it surface as a 500 unless explicitly handled?

💡 Result:

If you manually instantiate/validate a Pydantic model inside a FastAPI route handler and it raises pydantic.ValidationError, FastAPI will treat that as an internal error and the client will get HTTP 500, not an automatic 422.

FastAPI only auto-converts request parsing/validation failures into RequestValidationError (which has a default handler returning 422). A plain Pydantic ValidationError “in your response or anywhere in your code (not in the client’s request)” is considered a bug and results in 500 unless you catch/handle it yourself. [1]

So, if you want a 422, you must explicitly handle it, e.g. catch ValidationError and raise HTTPException(status_code=422, detail=e.errors(...)). [2]

Sources: [1] [2]


🏁 Script executed:

cd src/routers/openml && head -100 flows.py | cat -n

Repository: openml/server-api

Length of output: 3829


🏁 Script executed:

rg "FlowExistsBody" --type py -A 5 -B 2

Repository: openml/server-api

Length of output: 4002


🏁 Script executed:

rg "def flow_exists" --type py -A 15

Repository: openml/server-api

Length of output: 2056


Add Path constraints to the GET alias to match POST validation and ensure consistent 422 responses.

The GET route accepts name and external_version as bare path parameters with no validation, then manually instantiates FlowExistsBody(...) inside the handler. When Pydantic validation fails in a manually instantiated model, FastAPI treats it as an unhandled error returning 500, while the POST route returns 422 for the same invalid input. Apply the same constraints to path parameters as defined in FlowExistsBody (min_length=1, max_length=1024 for name; min_length=1, max_length=128 for external_version).

🛠️ Suggested fix
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Path
@@
 async def flow_exists_get(
-    name: str,
-    external_version: str,
+    name: Annotated[str, Path(min_length=1, max_length=1024)],
+    external_version: Annotated[str, Path(min_length=1, max_length=128)],
     expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
 ) -> dict[Literal["flow_id"], int]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@router.get("/exists/{name}/{external_version}", deprecated=True)
async def flow_exists_get(
name: str,
external_version: str,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Use POST /flows/exists instead."""
return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)
@router.get("/exists/{name}/{external_version}", deprecated=True)
async def flow_exists_get(
name: Annotated[str, Path(min_length=1, max_length=1024)],
external_version: Annotated[str, Path(min_length=1, max_length=128)],
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Use POST /flows/exists instead."""
return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/routers/openml/flows.py` around lines 34 - 41, The GET handler
flow_exists_get currently accepts raw path params and constructs
FlowExistsBody(name=..., external_version=...), causing Pydantic validation
errors to become 500s; update the function signature to apply the same Path
constraints as FlowExistsBody (name: str = Path(..., min_length=1,
max_length=1024), external_version: str = Path(..., min_length=1,
max_length=128)) so FastAPI validates inputs and returns 422 on bad input, and
add the required Path import from fastapi.



@router.get("/{flow_id}")
async def get_flow(
flow_id: int,
Expand Down
5 changes: 5 additions & 0 deletions src/schemas/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pydantic import BaseModel, ConfigDict, Field


class FlowExistsBody(BaseModel):
    # Payload identifying a flow by its name and external version.
    # Both values must be non-empty strings.

    # At most 1024 characters.
    name: str = Field(
        min_length=1,
        max_length=1024,
    )
    # At most 128 characters.
    external_version: str = Field(
        min_length=1,
        max_length=128,
    )


class Parameter(BaseModel):
name: str
default_value: Any
Expand Down
88 changes: 88 additions & 0 deletions tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,91 @@ async def test_dataset_tag_invalid_tag_is_rejected(

assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert new.json()["detail"][0]["loc"] == ["body", "tag"]


@pytest.mark.parametrize(
    "key",
    [None, ApiKey.INVALID],
    ids=["no authentication", "invalid key"],
)
async def test_dataset_untag_rejects_unauthorized(
    key: ApiKey | None,
    py_api: httpx.AsyncClient,
) -> None:
    """Untagging without an API key, or with an invalid one, is rejected as unauthenticated."""
    # `key` is None for the "no authentication" case, so no query string is sent.
    query = "" if key is None else f"?api_key={key}"
    response = await py_api.post(
        f"/datasets/untag{query}",
        json={"data_id": 1, "tag": "study_14"},
    )
    assert response.status_code == HTTPStatus.UNAUTHORIZED
    assert response.headers["content-type"] == "application/problem+json"
    error = response.json()
    assert error["type"] == AuthenticationFailedError.uri
    assert error["code"] == "103"


async def test_dataset_untag(py_api: httpx.AsyncClient, expdb_test: AsyncConnection) -> None:
    """A tag added with one API key can be removed again with that same key."""
    dataset_id, tag = 1, "temp_dataset_untag"
    payload = {"data_id": dataset_id, "tag": tag}
    await py_api.post(f"/datasets/tag?api_key={ApiKey.SOME_USER}", json=payload)

    response = await py_api.post(f"/datasets/untag?api_key={ApiKey.SOME_USER}", json=payload)

    assert response.status_code == HTTPStatus.OK
    assert response.json() == {"data_untag": {"id": str(dataset_id)}}
    remaining_tags = await get_tags_for(id_=dataset_id, connection=expdb_test)
    assert tag not in remaining_tags


async def test_dataset_untag_rejects_other_user(py_api: httpx.AsyncClient) -> None:
    """A tag cannot be removed using a different user's key than the one that added it."""
    payload = {"data_id": 1, "tag": "temp_dataset_untag_not_owned"}
    await py_api.post(f"/datasets/tag?api_key={ApiKey.SOME_USER}", json=payload)

    response = await py_api.post(f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json=payload)
    assert response.status_code == HTTPStatus.FORBIDDEN
    assert response.json()["code"] == "476"
    assert "not created by you" in response.json()["detail"]

    # Remove the tag with the key that created it so it does not linger.
    cleanup = await py_api.post(f"/datasets/untag?api_key={ApiKey.SOME_USER}", json=payload)
    assert cleanup.status_code == HTTPStatus.OK


async def test_dataset_untag_fails_if_tag_does_not_exist(py_api: httpx.AsyncClient) -> None:
    """Untagging a tag the dataset does not carry yields a not-found error."""
    payload = {"data_id": 1, "tag": "definitely_not_a_dataset_tag"}
    response = await py_api.post(f"/datasets/untag?api_key={ApiKey.ADMIN}", json=payload)

    assert response.status_code == HTTPStatus.NOT_FOUND
    assert response.json()["code"] == "475"
    assert "does not have tag" in response.json()["detail"]


@pytest.mark.parametrize(
    "tag",
    ["", "h@", " a", "a" * 65],
    ids=["too short", "@", "space", "too long"],
)
async def test_dataset_untag_invalid_tag_is_rejected(
    tag: str,
    py_api: httpx.AsyncClient,
) -> None:
    """Malformed tag values fail validation on the `tag` body field."""
    payload = {"data_id": 1, "tag": tag}
    response = await py_api.post(f"/datasets/untag?api_key={ApiKey.ADMIN}", json=payload)

    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    first_error = response.json()["detail"][0]
    assert first_error["loc"] == ["body", "tag"]
81 changes: 75 additions & 6 deletions tests/routers/openml/flows_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import httpx
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import FlowNotFoundError
from routers.openml.flows import flow_exists
from schemas.flows import FlowExistsBody
from tests.conftest import Flow


Expand All @@ -28,7 +30,7 @@ async def test_flow_exists_calls_db_correctly(
"database.flows.get_by_name",
new_callable=mocker.AsyncMock,
)
await flow_exists(name, external_version, expdb_test)
await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test)
mocked_db.assert_called_once_with(
name=name,
external_version=external_version,
Expand All @@ -51,29 +53,42 @@ async def test_flow_exists_processes_found(
new_callable=mocker.AsyncMock,
return_value=fake_flow,
)
response = await flow_exists("name", "external_version", expdb_test)
response = await flow_exists(
FlowExistsBody(name="name", external_version="external_version"),
expdb_test,
)
assert response == {"flow_id": fake_flow.id}


async def test_flow_exists_handles_flow_not_found(
mocker: MockerFixture, expdb_test: AsyncConnection
) -> None:
mocker.patch("database.flows.get_by_name", return_value=None)
mocker.patch(
"database.flows.get_by_name",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(FlowNotFoundError) as error:
await flow_exists("foo", "bar", expdb_test)
await flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test)
assert error.value.status_code == HTTPStatus.NOT_FOUND
assert error.value.uri == FlowNotFoundError.uri


async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None:
response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
response = await py_api.post(
"/flows/exists",
json={"name": flow.name, "external_version": flow.external_version},
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}


async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None:
name, version = "foo", "bar"
response = await py_api.get(f"/flows/exists/{name}/{version}")
response = await py_api.post(
"/flows/exists",
json={"name": name, "external_version": version},
)
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.headers["content-type"] == "application/problem+json"
error = response.json()
Expand All @@ -82,6 +97,60 @@ async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None:
assert version in error["detail"]


@pytest.mark.parametrize(
    ("name", "external_version"),
    [
        ("", "v1"),
        ("some-flow", ""),
    ],
)
async def test_flow_exists_rejects_empty_fields(
    py_api: httpx.AsyncClient,
    name: str,
    external_version: str,
) -> None:
    """An empty name or an empty external version fails request validation."""
    payload = {"name": name, "external_version": external_version}
    response = await py_api.post("/flows/exists", json=payload)
    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY


async def test_flow_exists_with_uri_unsafe_chars(
    py_api: httpx.AsyncClient,
    expdb_test: AsyncConnection,
) -> None:
    """A flow whose name contains URI-unsafe characters is found via the POST body."""
    name = "sklearn.pipeline.Pipeline(steps=[('a','b')])"
    external_version = "v1"

    # Insert a flow with that name directly, then read back its generated id.
    insert_flow = text(
        """
        INSERT INTO implementation(fullname,name,version,external_version,uploadDate)
        VALUES (:fullname,:name,2,:external_version,'2024-02-02 02:23:23');
        """,
    )
    await expdb_test.execute(
        insert_flow,
        parameters={
            "fullname": name,
            "name": name,
            "external_version": external_version,
        },
    )
    inserted = await expdb_test.execute(text("""SELECT LAST_INSERT_ID();"""))
    (flow_id,) = inserted.one()

    payload = {"name": name, "external_version": external_version}
    response = await py_api.post("/flows/exists", json=payload)
    assert response.status_code == HTTPStatus.OK
    assert response.json() == {"flow_id": flow_id}


async def test_flow_exists_get_deprecated(flow: Flow, py_api: httpx.AsyncClient) -> None:
    """The deprecated GET route still resolves an existing flow to its id."""
    legacy_url = f"/flows/exists/{flow.name}/{flow.external_version}"
    response = await py_api.get(legacy_url)
    assert response.status_code == HTTPStatus.OK
    assert response.json() == {"flow_id": flow.id}


async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None:
response = await py_api.get("/flows/1")
assert response.status_code == HTTPStatus.OK
Expand Down
Loading