diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 41254863..dbc467a9 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -7,29 +7,38 @@ from core.conversions import _str_to_num from core.errors import FlowNotFoundError from routers.dependencies import expdb_connection -from schemas.flows import Flow, Parameter, Subflow +from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow router = APIRouter(prefix="/flows", tags=["flows"]) -@router.get("/exists/{name}/{external_version}") +@router.post("/exists") async def flow_exists( - name: str, - external_version: str, + body: FlowExistsBody, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["flow_id"], int]: """Check if a Flow with the name and version exists, if so, return the flow id.""" flow = await database.flows.get_by_name( - name=name, - external_version=external_version, + name=body.name, + external_version=body.external_version, expdb=expdb, ) if flow is None: - msg = f"Flow with name {name} and external version {external_version} not found." + msg = f"Flow with name {body.name} and external version {body.external_version} not found." raise FlowNotFoundError(msg) return {"flow_id": flow.id} +@router.get("/exists/{name}/{external_version}", deprecated=True) +async def flow_exists_get( + name: str, + external_version: str, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[Literal["flow_id"], int]: + """Use POST /flows/exists instead.""" + return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb) + + @router.get("/{flow_id}") async def get_flow( flow_id: int, diff --git a/src/schemas/flows.py b/src/schemas/flows.py index a6cd479c..04ef1455 100644 --- a/src/schemas/flows.py +++ b/src/schemas/flows.py @@ -6,6 +6,11 @@ from pydantic import BaseModel, ConfigDict, Field +class FlowExistsBody(BaseModel): + name: str = Field(min_length=1, max_length=1024) + external_version: str = Field(min_length=1, max_length=128) + + class Parameter(BaseModel): name: str default_value: Any diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index 400ec4c0..8b9fd3c2 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -8,6 +8,7 @@ from core.errors import FlowNotFoundError from routers.openml.flows import flow_exists +from schemas.flows import FlowExistsBody from tests.conftest import Flow @@ -28,7 +29,7 @@ async def test_flow_exists_calls_db_correctly( "database.flows.get_by_name", new_callable=mocker.AsyncMock, ) - await flow_exists(name, external_version, expdb_test) + await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test) mocked_db.assert_called_once_with( name=name, external_version=external_version, @@ -51,29 +52,37 @@ async def test_flow_exists_processes_found( new_callable=mocker.AsyncMock, return_value=fake_flow, ) - response = await flow_exists("name", "external_version", expdb_test) + response = await flow_exists( + FlowExistsBody(name="name", external_version="external_version"), expdb_test + ) assert response == {"flow_id": fake_flow.id} async def test_flow_exists_handles_flow_not_found( mocker: MockerFixture, expdb_test: AsyncConnection ) -> None: - mocker.patch("database.flows.get_by_name", return_value=None) + mocker.patch( + "database.flows.get_by_name", + new_callable=mocker.AsyncMock, + return_value=None, + ) with pytest.raises(FlowNotFoundError) as error: - await flow_exists("foo", "bar", expdb_test) + await flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.uri == FlowNotFoundError.uri async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None: - response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + response = await py_api.post( + "/flows/exists", json={"name": flow.name, "external_version": flow.external_version} + ) assert response.status_code == HTTPStatus.OK assert response.json() == {"flow_id": flow.id} async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: name, version = "foo", "bar" - response = await py_api.get(f"/flows/exists/{name}/{version}") + response = await py_api.post("/flows/exists", json={"name": name, "external_version": version}) assert response.status_code == HTTPStatus.NOT_FOUND assert response.headers["content-type"] == "application/problem+json" error = response.json() @@ -82,6 +91,24 @@ async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: assert version in error["detail"] +@pytest.mark.parametrize( + ("name", "external_version"), + [ + ("", "v1"), + ("some-flow", ""), + ], +) +async def test_flow_exists_rejects_empty_fields( + py_api: httpx.AsyncClient, + name: str, + external_version: str, +) -> None: + response = await py_api.post( + "/flows/exists", json={"name": name, "external_version": external_version} + ) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/flows/1") assert response.status_code == HTTPStatus.OK diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py index 38d11e8c..f23ece59 100644 --- a/tests/routers/openml/migration/flows_migration_test.py +++ b/tests/routers/openml/migration/flows_migration_test.py @@ -21,7 +21,7 @@ async def test_flow_exists_not( ) -> None: path = "exists/foo/bar" py_response, php_response = await asyncio.gather( - py_api.get(f"/flows/{path}"), + py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"}), php_api.get(f"/flow/{path}"), ) @@ -45,7 +45,13 @@ async def test_flow_exists( ) -> None: path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}" py_response, php_response = await asyncio.gather( - py_api.get(f"/flows/{path}"), + py_api.post( + "/flows/exists", + json={ + "name": persisted_flow.name, + "external_version": persisted_flow.external_version, + }, + ), php_api.get(f"/flow/{path}"), )