Skip to content
23 changes: 18 additions & 5 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
import database.flows
from core.conversions import _str_to_num
from routers.dependencies import expdb_connection
from schemas.flows import Flow, Parameter, Subflow
from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.get("/exists/{name}/{external_version}")
@router.post("/exists")
def flow_exists(
name: str,
external_version: str,
body: FlowExistsBody,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Check if a Flow with the name and version exists, if so, return the flow id."""
flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
flow = database.flows.get_by_name(
name=body.name,
external_version=body.external_version,
expdb=expdb,
)
if flow is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
Expand All @@ -28,6 +31,16 @@ def flow_exists(
return {"flow_id": flow.id}


@router.get("/exists/{name}/{external_version}", deprecated=True)
def flow_exists_get(
name: str,
external_version: str,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Deprecated: use POST /flows/exists instead."""
return flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)


@router.get("/{flow_id}")
def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
flow = database.flows.get(flow_id, expdb)
Expand Down
5 changes: 5 additions & 0 deletions src/schemas/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pydantic import BaseModel, ConfigDict, Field


class FlowExistsBody(BaseModel):
name: str = Field(min_length=1, max_length=1024)
external_version: str = Field(min_length=1, max_length=128)


class Parameter(BaseModel):
name: str
default_value: Any
Expand Down
33 changes: 28 additions & 5 deletions tests/routers/openml/flows_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.testclient import TestClient

from routers.openml.flows import flow_exists
from schemas.flows import FlowExistsBody
from tests.conftest import Flow


Expand All @@ -25,7 +26,7 @@ def test_flow_exists_calls_db_correctly(
mocker: MockerFixture,
) -> None:
mocked_db = mocker.patch("database.flows.get_by_name")
flow_exists(name, external_version, expdb_test)
flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test)
mocked_db.assert_called_once_with(
name=name,
external_version=external_version,
Expand All @@ -47,30 +48,52 @@ def test_flow_exists_processes_found(
"database.flows.get_by_name",
return_value=fake_flow,
)
response = flow_exists("name", "external_version", expdb_test)
response = flow_exists(
FlowExistsBody(name="name", external_version="external_version"), expdb_test
)
assert response == {"flow_id": fake_flow.id}


def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None:
mocker.patch("database.flows.get_by_name", return_value=None)
with pytest.raises(HTTPException) as error:
flow_exists("foo", "bar", expdb_test)
flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test)
assert error.value.status_code == HTTPStatus.NOT_FOUND
assert error.value.detail == "Flow not found."


def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
response = py_api.post(
"/flows/exists", json={"name": flow.name, "external_version": flow.external_version}
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}


def test_flow_exists_not_exists(py_api: TestClient) -> None:
response = py_api.get("/flows/exists/foo/bar")
response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == "Flow not found."


@pytest.mark.parametrize(
("name", "external_version"),
[
("", "v1"),
("some-flow", ""),
],
)
def test_flow_exists_rejects_empty_fields(
py_api: TestClient,
name: str,
external_version: str,
) -> None:
response = py_api.post(
"/flows/exists", json={"name": name, "external_version": external_version}
)
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY


def test_get_flow_no_subflow(py_api: TestClient) -> None:
response = py_api.get("/flows/1")
assert response.status_code == HTTPStatus.OK
Expand Down
15 changes: 9 additions & 6 deletions tests/routers/openml/migration/flows_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_flow_exists_not(
py_api: TestClient,
php_api: TestClient,
) -> None:
path = "exists/foo/bar"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")
py_response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
php_response = php_api.get("/flow/exists/foo/bar")

assert py_response.status_code == HTTPStatus.NOT_FOUND
assert php_response.status_code == HTTPStatus.OK
Expand All @@ -36,9 +35,13 @@ def test_flow_exists(
py_api: TestClient,
php_api: TestClient,
) -> None:
path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")
py_response = py_api.post(
"/flows/exists",
json={"name": persisted_flow.name, "external_version": persisted_flow.external_version},
)
php_response = php_api.get(
f"/flow/exists/{persisted_flow.name}/{persisted_flow.external_version}"
)

assert py_response.status_code == php_response.status_code, php_response.content

Expand Down