diff --git a/pyproject.toml b/pyproject.toml index d8078bf1..e1b7e609 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ ] description = "The Python-based REST API for OpenML." readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.14" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/src/core/access.py b/src/core/access.py index 558643f5..871fb48f 100644 --- a/src/core/access.py +++ b/src/core/access.py @@ -1,10 +1,11 @@ -from typing import Any - -from sqlalchemy.engine import Row +from typing import TYPE_CHECKING, Any from database.users import User from schemas.datasets.openml import Visibility +if TYPE_CHECKING: + from sqlalchemy.engine import Row + async def _user_has_access( dataset: Row[Any], diff --git a/src/core/errors.py b/src/core/errors.py index 69d7e0c7..5fe54376 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -5,11 +5,14 @@ """ from http import HTTPStatus +from typing import TYPE_CHECKING -from fastapi import Request -from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +if TYPE_CHECKING: + from fastapi import Request + from fastapi.exceptions import RequestValidationError + # ============================================================================= # Base Exception # ============================================================================= diff --git a/src/core/formatting.py b/src/core/formatting.py index f954e81d..406659fa 100644 --- a/src/core/formatting.py +++ b/src/core/formatting.py @@ -1,10 +1,12 @@ import html - -from sqlalchemy.engine import Row +from typing import TYPE_CHECKING from config import load_routing_configuration from schemas.datasets.openml import DatasetFileFormat +if TYPE_CHECKING: + from sqlalchemy.engine import Row + def _str_to_bool(string: str) -> bool: if string.casefold() in ["true", "1", "yes", "y"]: diff --git a/src/core/logging.py b/src/core/logging.py index b35270d9..6546f714 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -5,13 +5,16 @@ import uuid from collections.abc import Awaitable, Callable from pathlib import Path +from typing import TYPE_CHECKING from loguru import logger -from starlette.requests import Request -from starlette.responses import Response from config import load_configuration +if TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + def setup_log_sinks(configuration_file: Path | None = None) -> None: """Configure loguru based on app configuration.""" diff --git a/src/database/datasets.py b/src/database/datasets.py index d6f91706..efa77459 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,11 +2,10 @@ import datetime from collections import defaultdict +from typing import TYPE_CHECKING from sqlalchemy import text -from sqlalchemy.engine import Row from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncConnection from database.exceptions import ( _DUPLICATE_ENTRY, @@ -16,6 +15,10 @@ ) from schemas.datasets.openml import Feature +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + async def get(id_: int, connection: AsyncConnection) -> Row | None: row = await connection.execute( @@ -45,6 +48,33 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: return row.one_or_none() +async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None: + return ( + await connection.execute( + text( + """ + SELECT * + FROM dataset_tag + WHERE id = :dataset_id AND tag = :tag + """, + ), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + ).first() + + +async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None: + await connection.execute( + text( + """ + DELETE FROM dataset_tag + WHERE id = :dataset_id AND tag = :tag + """, + ), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]: row = await connection.execute( text( diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 74faf59b..382653fc 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,12 +1,14 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection from core.formatting import _str_to_bool from schemas.datasets.openml import EstimationProcedure +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: rows = await connection.execute( diff --git a/src/database/flows.py b/src/database/flows.py index 79bb6e5b..ed022c40 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,8 +1,10 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: diff --git a/src/database/qualities.py b/src/database/qualities.py index 08647f41..9180b137 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -1,11 +1,14 @@ from collections import defaultdict from collections.abc import Iterable +from typing import TYPE_CHECKING from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Quality +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]: row = await connection.execute( diff --git a/src/database/runs.py b/src/database/runs.py index acf7a532..6eef0b3b 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -1,10 +1,12 @@ """Database queries for run-related data.""" from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def exist(id_: int, expdb: AsyncConnection) -> bool: diff --git a/src/database/setups.py b/src/database/setups.py index 74498478..c8651d93 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -1,8 +1,12 @@ """All database operations that directly operate on setups.""" +from typing import TYPE_CHECKING + from sqlalchemy import text -from sqlalchemy.engine import Row, RowMapping -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.engine import Row, RowMapping + from sqlalchemy.ext.asyncio import AsyncConnection async def get(setup_id: int, connection: AsyncConnection) -> Row | None: diff --git a/src/database/studies.py b/src/database/studies.py index 286f7988..3c7c8e96 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -1,14 +1,16 @@ import re from collections.abc import Sequence from datetime import UTC, datetime -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from schemas.study import CreateStudy, StudyType +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None: row = await connection.execute( diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d26..ba4eeccb 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,8 +1,10 @@ from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection async def get(id_: int, expdb: AsyncConnection) -> Row | None: diff --git a/src/database/users.py b/src/database/users.py index 0b09fb0d..8a812d69 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -1,13 +1,15 @@ import dataclasses from enum import IntEnum -from typing import Annotated, Self +from typing import TYPE_CHECKING, Annotated, Self from pydantic import StringConstraints from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from config import load_configuration +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + # If `allow_test_api_keys` is set, the key may also be one of `normaluser`, # `normaluser2`, or `abc` (admin). api_key_pattern = r"^[0-9a-fA-F]{32}$" diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index dba12d8a..ca4a9654 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,15 +1,17 @@ from collections.abc import AsyncGenerator, AsyncIterator -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import Depends from loguru import logger from pydantic import BaseModel, Field -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.setup import expdb_database, user_database from database.users import APIKey, User +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def expdb_connection() -> AsyncIterator[AsyncConnection]: engine = expdb_database() diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index 998f940e..2749664f 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -5,10 +5,9 @@ """ import asyncio -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncConnection import config from core.errors import ServiceNotFoundError @@ -28,6 +27,9 @@ Quality, ) +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/mldcat_ap", tags=["MLDCAT-AP"]) _configuration = config.load_configuration() _server_url = ( diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 68d86aed..12b796f2 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -2,13 +2,12 @@ import re from datetime import datetime from enum import StrEnum -from typing import Annotated, Any, Literal, NamedTuple +from http import HTTPStatus +from typing import TYPE_CHECKING, Annotated, Any, Literal, NamedTuple, NotRequired, TypedDict -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, Query from loguru import logger from sqlalchemy import bindparam, text -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -26,6 +25,8 @@ InternalError, NoResultsError, TagAlreadyExistsError, + TagNotFoundError, + TagNotOwnedError, ) from core.formatting import ( _csv_as_list, @@ -50,6 +51,10 @@ ) from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -58,7 +63,7 @@ ) async def tag_dataset( data_id: Annotated[Identifier, Body()], - tag: Annotated[str, SystemString64], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, Any]]: @@ -80,6 +85,50 @@ async def tag_dataset( } +class TagInfo(TypedDict): + id: str + tag: NotRequired[SystemString64 | list[SystemString64]] + + +@router.post(path="/untag", deprecated=True) +async def untag_dataset_like_php( + data_id: Annotated[Identifier, Body()], + tag: Annotated[SystemString64, Body()], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[Literal["data_untag"], TagInfo]: + await untag_dataset(data_id, tag, user, expdb_db) + tags = await database.datasets.get_tags_for(id_=data_id, connection=expdb_db) + tag_info: TagInfo = {"id": str(data_id)} + if len(tags) == 1: + tag_info["tag"] = tags[0] + elif tags: + tag_info["tag"] = tags + return {"data_untag": tag_info} + + +@router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT) +async def untag_dataset( + identifier: Identifier, + tag: Annotated[SystemString64, Query()], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> None: + dataset_tag = await database.datasets.get_tag(identifier, tag, expdb_db) + if not dataset_tag: + try: + await _get_dataset_raise_otherwise(identifier, user, expdb_db) + except DatasetNotFoundError, DatasetNoAccessError: + msg = f"Cannot remove {tag!r}, because dataset {identifier} is not found." + raise DatasetNotFoundError(msg, code=472) from None + msg = f"Tag {tag!r} for dataset {identifier} not found." + raise TagNotFoundError(msg) + if dataset_tag.uploader != user.user_id and not (await user.is_admin()): + msg = f"You are not allowed to remove {tag!r} from dataset {identifier}." + raise TagNotOwnedError(msg) + await database.datasets.delete_tag(identifier, tag, expdb_db) + + class DatasetStatusFilter(StrEnum): ACTIVE = DatasetStatus.ACTIVE DEACTIVATED = DatasetStatus.DEACTIVATED @@ -108,27 +157,27 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.get(path="/list") async def list_datasets( # noqa: PLR0913, C901 pagination: Annotated[Pagination, Body(default_factory=Pagination)], - data_name: Annotated[str | None, CasualString128] = None, - tag: Annotated[str | None, SystemString64] = None, + data_name: Annotated[CasualString128 | None, Body()] = None, + tag: Annotated[SystemString64 | None, Body()] = None, data_version: Annotated[ - int | None, + Identifier | None, Body(description="The dataset version to include in the search."), ] = None, uploader: Annotated[ - int | None, + Identifier | None, Body(description="User id of the uploader whose datasets to include in the search."), ] = None, data_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body( description="The dataset(s) to include in the search. " "If none are specified, all datasets are included.", ), ] = None, - number_instances: Annotated[str | None, IntegerRange] = None, - number_features: Annotated[str | None, IntegerRange] = None, - number_classes: Annotated[str | None, IntegerRange] = None, - number_missing_values: Annotated[str | None, IntegerRange] = None, + number_instances: Annotated[IntegerRange | None, Body()] = None, + number_features: Annotated[IntegerRange | None, Body()] = None, + number_classes: Annotated[IntegerRange | None, Body()] = None, + number_missing_values: Annotated[IntegerRange | None, Body()] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, @@ -266,7 +315,7 @@ class ProcessingInformation(NamedTuple): async def _get_processing_information( - dataset_id: int, + dataset_id: Identifier, connection: AsyncConnection, ) -> ProcessingInformation: """Return processing information, if any. Otherwise, all fields `None`.""" @@ -285,7 +334,7 @@ async def _get_processing_information( async def _get_dataset_raise_otherwise( - dataset_id: int, + dataset_id: Identifier, user: User | None, expdb: AsyncConnection, ) -> Row[Any]: @@ -306,7 +355,7 @@ async def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) async def get_dataset_features( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: @@ -349,7 +398,7 @@ async def get_dataset_features( path="/status/update", ) async def update_dataset_status( - dataset_id: Annotated[int, Body()], + dataset_id: Annotated[Identifier, Body()], status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)], @@ -403,7 +452,7 @@ async def update_dataset_status( description="Get meta-data for dataset with ID `dataset_id`.", ) async def get_dataset( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)] = None, user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index b07c2c0e..d8532dcb 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -1,12 +1,14 @@ -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection from schemas.datasets.openml import EstimationProcedure +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/estimationprocedure", tags=["estimationprocedure"]) diff --git a/src/routers/openml/evaluations.py b/src/routers/openml/evaluations.py index f6650b36..f2891312 100644 --- a/src/routers/openml/evaluations.py +++ b/src/routers/openml/evaluations.py @@ -1,11 +1,13 @@ -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/evaluationmeasure", tags=["evaluationmeasure"]) diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 44c23389..b3c3f219 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,15 +1,18 @@ import asyncio -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from core.conversions import _str_to_num from core.errors import FlowNotFoundError from routers.dependencies import expdb_connection +from routers.types import Identifier from schemas.flows import Flow, Parameter, Subflow +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/flows", tags=["flows"]) @@ -33,7 +36,7 @@ async def flow_exists( @router.get("/{flow_id}") async def get_flow( - flow_id: int, + flow_id: Identifier, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Flow: flow = await database.flows.get(flow_id, expdb) diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index eff7081a..dc7a3f55 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -1,7 +1,6 @@ -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -14,8 +13,12 @@ ) from database.users import User from routers.dependencies import expdb_connection, fetch_user +from routers.types import Identifier from schemas.datasets.openml import Quality +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -33,7 +36,7 @@ async def list_qualities( @router.get("/qualities/{dataset_id}") async def get_qualities( - dataset_id: int, + dataset_id: Identifier, user: Annotated[User | None, Depends(fetch_user)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> list[Quality]: diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 37a7cecf..4fc79482 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,21 +1,24 @@ """Endpoints for run-related data.""" -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncConnection import database.runs from core.errors import RunNotFoundError, RunTraceNotFoundError from routers.dependencies import expdb_connection +from routers.types import Identifier from schemas.runs import RunTrace, TraceIteration +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/run", tags=["run"]) @router.get("/trace/{run_id}") async def get_run_trace( - run_id: int, + run_id: Identifier, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> RunTrace: """Get trace data for a run by run ID.""" diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py index 89abb240..dbca9038 100644 --- a/src/routers/openml/setups.py +++ b/src/routers/openml/setups.py @@ -1,11 +1,10 @@ """All endpoints that relate to setups.""" import asyncio -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Body, Depends, Path from loguru import logger -from sqlalchemy.ext.asyncio import AsyncConnection import database.setups from core.errors import ( @@ -16,15 +15,18 @@ ) from database.users import User from routers.dependencies import expdb_connection, fetch_user_or_raise -from routers.types import SystemString64 +from routers.types import Identifier, SystemString64 from schemas.setups import SetupParameters, SetupResponse +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/setup", tags=["setup"]) @router.get(path="/{setup_id}", response_model_exclude_none=True) async def get_setup( - setup_id: Annotated[int, Path()], + setup_id: Annotated[Identifier, Path()], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> SetupResponse: """Get setup by id.""" @@ -46,8 +48,8 @@ async def get_setup( @router.post(path="/tag") async def tag_setup( - setup_id: Annotated[int, Body()], - tag: Annotated[str, SystemString64], + setup_id: Annotated[Identifier, Body()], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: @@ -73,8 +75,8 @@ async def tag_setup( @router.post(path="/untag") async def untag_setup( - setup_id: Annotated[int, Body()], - tag: Annotated[str, SystemString64], + setup_id: Annotated[Identifier, Body()], + tag: Annotated[SystemString64, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 56c670c3..ccdf9ff2 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -1,10 +1,8 @@ -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Body, Depends from loguru import logger from pydantic import BaseModel -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection import database.studies from core.errors import ( @@ -20,14 +18,19 @@ from core.formatting import _str_to_bool from database.users import User from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise +from routers.types import Identifier from schemas.core import Visibility from schemas.study import CreateStudy, Study, StudyStatus, StudyType +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/studies", tags=["studies"]) async def _get_study_raise_otherwise( - id_or_alias: int | str, + id_or_alias: Identifier | str, user: User | None, expdb: AsyncConnection, ) -> Row: @@ -61,8 +64,8 @@ class AttachDetachResponse(BaseModel): @router.post("/attach") async def attach_to_study( - study_id: Annotated[int, Body()], - entity_ids: Annotated[list[int], Body()], + study_id: Annotated[Identifier, Body()], + entity_ids: Annotated[list[Identifier], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: @@ -148,7 +151,7 @@ async def create_study( @router.get("/{alias_or_id}") async def get_study( - alias_or_id: int | str, + alias_or_id: Identifier | str, user: Annotated[User | None, Depends(fetch_user)] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 3dfa5949..6627d79f 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -2,22 +2,30 @@ import json import re from enum import StrEnum -from typing import Annotated, Any, cast +from typing import TYPE_CHECKING, Annotated, Any, cast import xmltodict from fastapi import APIRouter, Body, Depends from sqlalchemy import bindparam, text -from sqlalchemy.engine import RowMapping -from sqlalchemy.ext.asyncio import AsyncConnection import config import database.datasets import database.tasks from core.errors import InternalError, NoResultsError, TaskNotFoundError from routers.dependencies import Pagination, expdb_connection -from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex +from routers.types import ( + CasualString128, + Identifier, + IntegerRange, + SystemString64, + integer_range_regex, +) from schemas.datasets.openml import Task +if TYPE_CHECKING: + from sqlalchemy.engine import RowMapping + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/tasks", tags=["tasks"]) type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None @@ -221,23 +229,23 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.get(path="/list") async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 pagination: Annotated[Pagination, Body(default_factory=Pagination)], - task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, - tag: Annotated[str | None, SystemString64] = None, - data_tag: Annotated[str | None, SystemString64] = None, + task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None, + tag: Annotated[SystemString64 | None, Body()] = None, + data_tag: Annotated[SystemString64 | None, Body()] = None, status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, task_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body(description="Filter by task id(s).", min_length=1), ] = None, data_id: Annotated[ - list[int] | None, + list[Identifier] | None, Body(description="Filter by dataset id(s).", min_length=1), ] = None, - data_name: Annotated[str | None, CasualString128] = None, - number_instances: Annotated[str | None, IntegerRange] = None, - number_features: Annotated[str | None, IntegerRange] = None, - number_classes: Annotated[str | None, IntegerRange] = None, - number_missing_values: Annotated[str | None, IntegerRange] = None, + data_name: Annotated[CasualString128 | None, Body()] = None, + number_instances: Annotated[IntegerRange | None, Body()] = None, + number_features: Annotated[IntegerRange | None, Body()] = None, + number_classes: Annotated[IntegerRange | None, Body()] = None, + number_missing_values: Annotated[IntegerRange | None, Body()] = None, expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5355e451..7cdba50e 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,15 +1,17 @@ import json -from typing import Annotated, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast from fastapi import APIRouter, Depends -from sqlalchemy.engine import Row -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import TaskTypeNotFoundError from database.tasks import get_input_for_task_type, get_task_types from database.tasks import get_task_type as db_get_task_type from routers.dependencies import expdb_connection +if TYPE_CHECKING: + from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncConnection + router = APIRouter(prefix="/tasktype", tags=["tasks"]) diff --git a/src/routers/types.py b/src/routers/types.py index e107ff35..fcdb876a 100644 --- a/src/routers/types.py +++ b/src/routers/types.py @@ -1,18 +1,18 @@ from typing import Annotated -from fastapi import Body from pydantic import Field -SystemString64 = Body(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64) - -CasualString128 = Body(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128) - +SystemString64 = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)] +CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)] Identifier = Annotated[int, Field(gt=0)] integer_range_regex = r"^(\d+)(\.\.\d+)?$" -IntegerRange = Body( - pattern=integer_range_regex, - description="Either a single integer, or a range defined as `low..high`, where" - "`low` and `high` are inclusive integer bounds of the range.", - examples=["12", "3..150"], -) +IntegerRange = Annotated[ + str, + Field( + pattern=integer_range_regex, + description="Either a single integer, or a range defined as `low..high`, where" + "`low` and `high` are inclusive integer bounds of the range.", + examples=["12", "3..150"], + ), +] diff --git a/tests/conftest.py b/tests/conftest.py index 368b789b..01ac8c82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,23 +2,25 @@ import json from collections.abc import AsyncIterator, Callable, Iterable, Iterator from pathlib import Path -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import _pytest.mark import httpx import pytest -from _pytest.config import Config -from _pytest.nodes import Item +from _pytest.config import Config # noqa: TC002 used during collection by Pytest +from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest from asgi_lifespan import LifespanManager -from fastapi import FastAPI from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from database.setup import expdb_database, user_database from main import create_api from routers.dependencies import expdb_connection, userdb_connection from tests.users import OWNER_USER +if TYPE_CHECKING: + from fastapi import FastAPI + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + PHP_API_URL = "http://php-api:80/api/v1/json" @@ -175,19 +177,27 @@ def temporary_tags( ) -> Callable[..., contextlib.AbstractAsyncContextManager[None]]: @contextlib.asynccontextmanager async def _temporary_tags( - tags: Iterable[str], setup_id: int, *, persist: bool = False + table: str, + tags: Iterable[str], + identifier: int, + *, + persist: bool = False, ) -> AsyncIterator[None]: insert_queries = [ ( - "INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);", - {"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id}, + f"INSERT INTO {table}(`id`,`tag`,`uploader`) VALUES (:identifier, :tag, :user_id);", # noqa: S608 # No user provided values + { + "identifier": identifier, + "tag": tag, + "user_id": OWNER_USER.user_id, + }, ) for tag in tags ] delete_queries = [ ( - "DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag", - {"setup_id": setup_id, "tag": tag}, + f"DELETE FROM {table} WHERE `id`=:identifier AND `tag`=:tag", # noqa: S608 # No user provided values + {"identifier": identifier, "tag": tag}, ) for tag in tags ] diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py index a8b98d84..6e345773 100644 --- a/tests/database/flows_test.py +++ b/tests/database/flows_test.py @@ -1,8 +1,11 @@ -from sqlalchemy.ext.asyncio import AsyncConnection +from typing import TYPE_CHECKING import database.flows from tests.conftest import Flow +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_database_flow_exists(flow: Flow, expdb_test: AsyncConnection) -> None: retrieved_flow = await database.flows.get_by_name(flow.name, flow.external_version, expdb_test) diff --git a/tests/dependencies/fetch_user_test.py b/tests/dependencies/fetch_user_test.py index 116bbdd9..c9bfa08c 100644 --- a/tests/dependencies/fetch_user_test.py +++ b/tests/dependencies/fetch_user_test.py @@ -1,13 +1,16 @@ from contextlib import aclosing +from typing import TYPE_CHECKING import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.users import User from routers.dependencies import fetch_user, fetch_user_or_raise from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + @pytest.mark.parametrize( ("api_key", "user"), diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 79c982f3..5bcfb4a3 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,9 +1,8 @@ import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.conversions import nested_remove_single_element_list from core.errors import DatasetNotFoundError, TagAlreadyExistsError @@ -13,6 +12,10 @@ from tests import constants from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + @pytest.mark.parametrize( "key", diff --git a/tests/routers/openml/dataset_untag_test.py b/tests/routers/openml/dataset_untag_test.py new file mode 100644 index 00000000..cfaaf249 --- /dev/null +++ b/tests/routers/openml/dataset_untag_test.py @@ -0,0 +1,222 @@ +"""Tests for untagging a dataset. + +There are currently two endpoints for untagging a dataset: + + POST /datasets/untag + DEL /datasets/{id}/tag + +The former is provided for compatibility with the old API, and is tested in the migration test. +The latter is more semantically correct, and is used for the Python tests. +They share most of the underlying logic anyway. +""" + +import asyncio +from collections.abc import Callable +from contextlib import AbstractAsyncContextManager +from http import HTTPStatus +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy import text + +from core.errors import DatasetNotFoundError, TagNotFoundError, TagNotOwnedError +from routers.openml.datasets import untag_dataset +from tests.users import ADMIN_USER, SOME_USER, ApiKey + +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + + +async def test_dataset_untag_success( + py_api: httpx.AsyncClient, expdb_test: AsyncConnection +) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 2)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + response = await py_api.delete( + f"/datasets/{dataset_id}/tag?api_key={ApiKey.SOME_USER}&tag={tag}", + ) + + assert response.status_code == HTTPStatus.NO_CONTENT + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() is None + + +async def test_dataset_untag_tag_does_not_exist(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + + with pytest.raises(TagNotFoundError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + + assert e.value.status_code == HTTPStatus.NOT_FOUND + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail + + +async def test_dataset_untag_tag_not_owned(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 1)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + with pytest.raises(TagNotOwnedError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail + + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() == 1 + + +async def test_dataset_untag_admin_bypasses_ownership(expdb_test: AsyncConnection) -> None: + dataset_id = 1 + tag = "foo" + await expdb_test.execute( + text("INSERT INTO dataset_tag(id, tag, uploader) VALUES (:dataset_id, :tag, 1)"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + + await untag_dataset(dataset_id, tag, ADMIN_USER, expdb_test) + + tag_present = await expdb_test.execute( + text("SELECT 1 FROM dataset_tag WHERE id=:dataset_id AND tag=:tag"), + parameters={"dataset_id": dataset_id, "tag": tag}, + ) + assert tag_present.scalar() is None + + +async def test_dataset_untag_dataset_does_not_exist(expdb_test: AsyncConnection) -> None: + dataset_id = 9_999_999 + tag = "foo" + + with pytest.raises(DatasetNotFoundError) as e: + await untag_dataset(dataset_id, tag, SOME_USER, expdb_test) + + assert e.value.status_code == HTTPStatus.NOT_FOUND + assert tag in e.value.detail + assert str(dataset_id) in e.value.detail + + +@pytest.mark.mut +@pytest.mark.parametrize("existing_tags", [[], ["bar"], ["bar", "bazz"]]) +async def test_dataset_untag_success_is_identical( + existing_tags: list[str], + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, + temporary_tags: Callable[..., AbstractAsyncContextManager[None]], +) -> None: + dataset_id = 101 # The first dataset without a pre-existing tag + tag = "foo" + + async with temporary_tags( + table="dataset_tag", tags=[tag, *existing_tags], identifier=dataset_id, persist=True + ): + php_response = await php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ) + + async with temporary_tags( + table="dataset_tag", tags=[tag, *existing_tags], identifier=dataset_id + ): + py_response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ) + + assert py_response.status_code == php_response.status_code + assert py_response.json() == php_response.json() + + +@pytest.mark.mut +async def test_dataset_untag_tag_does_not_exist_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + dataset_id = 1 + tag = "foo" + + py_response, php_response = await asyncio.gather( + py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ), + php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ), + ) + + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Tag not found." + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not found" in py_response.json()["detail"] + + +@pytest.mark.mut +async def test_dataset_untag_dataset_does_not_exist_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + dataset_id = 9_999_999 + tag = "foo" + + py_response, php_response = await asyncio.gather( + py_api.post( + f"/datasets/untag?api_key={ApiKey.OWNER_USER}", json={"tag": tag, "data_id": dataset_id} + ), + php_api.post( + f"/data/untag?api_key={ApiKey.OWNER_USER}", data={"tag": tag, "data_id": dataset_id} + ), + ) + + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Entity not found." + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not found" in py_response.json()["detail"] + + +@pytest.mark.mut +async def test_dataset_untag_tag_not_owned_is_identical( + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, + temporary_tags: Callable[..., AbstractAsyncContextManager[None]], +) -> None: + dataset_id = 1 + tag = "foo" + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id, persist=True): + php_response = await php_api.post( + f"/data/untag?api_key={ApiKey.SOME_USER}", data={"tag": tag, "data_id": dataset_id} + ) + + async with temporary_tags(table="dataset_tag", tags=[tag], identifier=dataset_id): + py_response = await py_api.post( + f"/datasets/untag?api_key={ApiKey.SOME_USER}", json={"tag": tag, "data_id": dataset_id} + ) + + assert py_response.status_code == HTTPStatus.FORBIDDEN + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Tag is not owned by you" + assert tag in py_response.json()["detail"] + assert str(dataset_id) in py_response.json()["detail"] + assert "not allowed" in py_response.json()["detail"] diff --git a/tests/routers/openml/datasets_features_test.py b/tests/routers/openml/datasets_features_test.py index 1fd8985f..36166596 100644 --- a/tests/routers/openml/datasets_features_test.py +++ b/tests/routers/openml/datasets_features_test.py @@ -3,16 +3,19 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import DatasetNoAccessError, DatasetNotFoundError, DatasetProcessingError from database.users import User from routers.openml.datasets import get_dataset_features from tests.users import ADMIN_USER, DATASET_130_OWNER +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_get_features_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/features/4") diff --git a/tests/routers/openml/datasets_get_test.py b/tests/routers/openml/datasets_get_test.py index 4b9fb33c..3ef034eb 100644 --- a/tests/routers/openml/datasets_get_test.py +++ b/tests/routers/openml/datasets_get_test.py @@ -4,11 +4,10 @@ import json import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection import tests.constants from core.errors import DatasetNoAccessError, DatasetNotFoundError @@ -17,6 +16,10 @@ from schemas.datasets.openml import DatasetMetadata from tests.users import ADMIN_USER, DATASET_130_OWNER, NO_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_get_dataset_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/1") @@ -89,20 +92,24 @@ async def test_dataset_no_500_with_multiple_processing_entries( @pytest.mark.parametrize( "dataset_id", - [-1, 138, 100_000], + [138, 100_000], ) async def test_get_dataset_not_found( dataset_id: int, expdb_test: AsyncConnection, user_test: AsyncConnection, ) -> None: - with pytest.raises(DatasetNotFoundError): + with pytest.raises(DatasetNotFoundError) as exc_info: await get_dataset( dataset_id=dataset_id, user=None, user_db=user_test, expdb_db=expdb_test, ) + assert exc_info.value.status_code == HTTPStatus.NOT_FOUND + _dataset_get_not_found_code = 111 + assert exc_info.value.code == _dataset_get_not_found_code + assert exc_info.value.detail.startswith("No dataset") @pytest.mark.parametrize( @@ -235,24 +242,21 @@ async def test_dataset_response_is_identical( # noqa: C901, PLR0912 assert py_json == php_json -@pytest.mark.parametrize( - "dataset_id", - [-1, 138, 100_000], -) -async def test_error_unknown_dataset( - dataset_id: int, +async def test_dataset_not_found_is_identical( py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, ) -> None: - response = await py_api.get(f"/datasets/{dataset_id}") + dataset_id = 9_999_999 + py_response, php_response = await asyncio.gather( + py_api.get(f"/datasets/{dataset_id}"), + php_api.get(f"/data/{dataset_id}"), + ) - # The new API has "404 Not Found" instead of "412 PRECONDITION_FAILED" - assert response.status_code == HTTPStatus.NOT_FOUND - # RFC 9457: Python API now returns problem+json format - assert response.headers["content-type"] == "application/problem+json" - error = response.json() - assert error["code"] == "111" - # instead of 'Unknown dataset' - assert error["detail"].startswith("No dataset") + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert py_response.json()["detail"] == f"No dataset with id {dataset_id} found." + assert php_response.json()["error"]["message"] == "Unknown dataset" async def test_private_dataset_no_user_no_access( diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index 7460fb52..868e9ee6 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -1,13 +1,12 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any -import httpx +import httpx # noqa: TC002 is used in a function signature inspected at runtime import hypothesis import pytest from hypothesis import given from hypothesis import strategies as st -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import NoResultsError from database.users import User @@ -16,6 +15,9 @@ from tests import constants from tests.users import ADMIN_USER, DATASET_130_OWNER, SOME_USER, ApiKey +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_list_route(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/list/") diff --git a/tests/routers/openml/datasets_qualities_test.py b/tests/routers/openml/datasets_qualities_test.py index fb3559ce..d67e54a9 100644 --- a/tests/routers/openml/datasets_qualities_test.py +++ b/tests/routers/openml/datasets_qualities_test.py @@ -1,11 +1,14 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx import pytest +if TYPE_CHECKING: + import httpx + async def test_get_quality(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/datasets/qualities/1") diff --git a/tests/routers/openml/datasets_status_test.py b/tests/routers/openml/datasets_status_test.py index 1e2271fc..adc5892e 100644 --- a/tests/routers/openml/datasets_status_test.py +++ b/tests/routers/openml/datasets_status_test.py @@ -1,10 +1,9 @@ """Tests for the POST /datasets/status/update endpoint.""" from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import DatasetAdminOnlyError, DatasetNotOwnedError from routers.openml.datasets import update_dataset_status @@ -12,6 +11,10 @@ from tests import constants from tests.users import ADMIN_USER, SOME_USER +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_update_status_via_api(py_api: httpx.AsyncClient) -> None: response = await py_api.post( diff --git a/tests/routers/openml/estimation_procedure_test.py b/tests/routers/openml/estimation_procedure_test.py index a05b34dc..dafa35c4 100644 --- a/tests/routers/openml/estimation_procedure_test.py +++ b/tests/routers/openml/estimation_procedure_test.py @@ -1,8 +1,9 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any -import httpx +if TYPE_CHECKING: + import httpx async def test_estimation_procedure_list(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/evaluation_measure_test.py b/tests/routers/openml/evaluation_measure_test.py index 2df2483f..feecb6b4 100644 --- a/tests/routers/openml/evaluation_measure_test.py +++ b/tests/routers/openml/evaluation_measure_test.py @@ -1,7 +1,9 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx +if TYPE_CHECKING: + import httpx async def test_evaluationmeasure_list(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/flows_exists_test.py b/tests/routers/openml/flows_exists_test.py index bb09edd6..988402f3 100644 --- a/tests/routers/openml/flows_exists_test.py +++ b/tests/routers/openml/flows_exists_test.py @@ -1,16 +1,19 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest -from pytest_mock import MockerFixture -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import FlowNotFoundError from routers.openml.flows import flow_exists from tests.conftest import Flow +if TYPE_CHECKING: + import httpx + from pytest_mock import MockerFixture + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None: response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") diff --git a/tests/routers/openml/flows_get_test.py b/tests/routers/openml/flows_get_test.py index 17bbfcca..862bb702 100644 --- a/tests/routers/openml/flows_get_test.py +++ b/tests/routers/openml/flows_get_test.py @@ -1,9 +1,8 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any import deepdiff.diff -import httpx import pytest from core.conversions import ( @@ -11,6 +10,9 @@ nested_str_to_num, ) +if TYPE_CHECKING: + import httpx + async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/flows/1") diff --git a/tests/routers/openml/qualities_list_test.py b/tests/routers/openml/qualities_list_test.py index 8eb51a58..84cb94b2 100644 --- a/tests/routers/openml/qualities_list_test.py +++ b/tests/routers/openml/qualities_list_test.py @@ -1,10 +1,13 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection + +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection async def _remove_quality_from_database(quality_name: str, expdb_test: AsyncConnection) -> None: diff --git a/tests/routers/openml/runs_trace_test.py b/tests/routers/openml/runs_trace_test.py index 11fd10ac..a9e664c6 100644 --- a/tests/routers/openml/runs_trace_test.py +++ b/tests/routers/openml/runs_trace_test.py @@ -2,15 +2,17 @@ import asyncio from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any import deepdiff -import httpx import pytest from core.conversions import nested_num_to_str from core.errors import RunNotFoundError, RunTraceNotFoundError +if TYPE_CHECKING: + import httpx + @pytest.mark.parametrize("run_id", [34]) async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/setups_get_test.py b/tests/routers/openml/setups_get_test.py index 6762714f..646fa0cc 100644 --- a/tests/routers/openml/setups_get_test.py +++ b/tests/routers/openml/setups_get_test.py @@ -1,12 +1,15 @@ import asyncio import re from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from core.conversions import nested_remove_values, nested_str_to_num +if TYPE_CHECKING: + import httpx + async def test_get_setup_unknown(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/setup/999999") diff --git a/tests/routers/openml/setups_tag_test.py b/tests/routers/openml/setups_tag_test.py index ad9659f6..c674ca2c 100644 --- a/tests/routers/openml/setups_tag_test.py +++ b/tests/routers/openml/setups_tag_test.py @@ -3,16 +3,19 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import SetupNotFoundError, TagAlreadyExistsError from routers.openml.setups import tag_setup from tests.users import SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_setup_tag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/tag", json={"setup_id": 1, "tag": "test_tag"}) @@ -111,7 +114,9 @@ async def test_setup_tag_response_is_identical_when_tag_doesnt_exist( # noqa: P setup_id = 1 tag = "totally_new_tag_for_migration_testing" - async with temporary_tags(tags=other_tags, setup_id=setup_id, persist=True): + async with temporary_tags( + table="setup_tag", tags=other_tags, identifier=setup_id, persist=True + ): php_response = await php_api.post( "/setup/tag", data={"api_key": api_key, "tag": tag, "setup_id": setup_id}, @@ -123,7 +128,7 @@ async def test_setup_tag_response_is_identical_when_tag_doesnt_exist( # noqa: P ) await expdb_test.commit() - async with temporary_tags(tags=other_tags, setup_id=setup_id): + async with temporary_tags(table="setup_tag", tags=other_tags, identifier=setup_id): py_response = await py_api.post( f"/setup/tag?api_key={api_key}", json={"setup_id": setup_id, "tag": tag}, @@ -182,7 +187,7 @@ async def test_setup_tag_response_is_identical_tag_already_exists( tag = "totally_new_tag_for_migration_testing" api_key = ApiKey.SOME_USER - async with temporary_tags(tags=[tag], setup_id=setup_id, persist=True): + async with temporary_tags(table="setup_tag", tags=[tag], identifier=setup_id, persist=True): # Both APIs can be tested in parallel since the tag is already persisted php_response, py_response = await asyncio.gather( php_api.post( diff --git a/tests/routers/openml/setups_untag_test.py b/tests/routers/openml/setups_untag_test.py index 1ed7b42e..1491fd1f 100644 --- a/tests/routers/openml/setups_untag_test.py +++ b/tests/routers/openml/setups_untag_test.py @@ -3,16 +3,19 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError from routers.openml.setups import untag_setup from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_setup_untag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"}) @@ -144,7 +147,7 @@ async def test_setup_untag_response_is_identical_when_tag_exists( tag = "totally_new_tag_for_migration_testing" all_tags = [tag, *other_tags] - async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True): + async with temporary_tags(table="setup_tag", tags=all_tags, identifier=setup_id, persist=True): php_response = await php_api.post( "/setup/untag", data={"api_key": api_key, "tag": tag, "setup_id": setup_id}, @@ -152,7 +155,7 @@ async def test_setup_untag_response_is_identical_when_tag_exists( # expdb_test transaction shared with Python API, # no commit needed and rolled back at the end of the test - async with temporary_tags(tags=all_tags, setup_id=setup_id): + async with temporary_tags(table="setup_tag", tags=all_tags, identifier=setup_id): py_response = await py_api.post( f"/setup/untag?api_key={api_key}", json={"setup_id": setup_id, "tag": tag}, diff --git a/tests/routers/openml/study_attach_test.py b/tests/routers/openml/study_attach_test.py index 2da1b8f0..dcb68704 100644 --- a/tests/routers/openml/study_attach_test.py +++ b/tests/routers/openml/study_attach_test.py @@ -1,14 +1,17 @@ from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import StudyConflictError from schemas.study import StudyType from tests.users import ApiKey +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def _attach_tasks_to_study( study_id: int, diff --git a/tests/routers/openml/study_get_test.py b/tests/routers/openml/study_get_test.py index 1ef2cff1..0762633a 100644 --- a/tests/routers/openml/study_get_test.py +++ b/tests/routers/openml/study_get_test.py @@ -1,11 +1,14 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx from core.conversions import nested_num_to_str, nested_remove_values +if TYPE_CHECKING: + import httpx + async def test_get_task_study_by_id(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/studies/1") diff --git a/tests/routers/openml/study_post_test.py b/tests/routers/openml/study_post_test.py index 0cb00fdb..3c3c5e9c 100644 --- a/tests/routers/openml/study_post_test.py +++ b/tests/routers/openml/study_post_test.py @@ -1,11 +1,14 @@ from datetime import UTC, datetime from http import HTTPStatus +from typing import TYPE_CHECKING -import httpx import pytest from tests.users import ApiKey +if TYPE_CHECKING: + import httpx + @pytest.mark.mut async def test_create_task_study(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/task_get_test.py b/tests/routers/openml/task_get_test.py index 955a7b81..7225061a 100644 --- a/tests/routers/openml/task_get_test.py +++ b/tests/routers/openml/task_get_test.py @@ -1,8 +1,8 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff -import httpx import pytest from core.conversions import ( @@ -11,6 +11,9 @@ nested_remove_values, ) +if TYPE_CHECKING: + import httpx + async def test_get_task(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/tasks/59") diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py index 45404d1b..7a3f562f 100644 --- a/tests/routers/openml/task_list_test.py +++ b/tests/routers/openml/task_list_test.py @@ -1,17 +1,19 @@ import asyncio from http import HTTPStatus -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import deepdiff -import httpx import pytest -from sqlalchemy.ext.asyncio import AsyncConnection from core.conversions import nested_remove_single_element_list from core.errors import NoResultsError from routers.dependencies import LIMIT_MAX, Pagination from routers.openml.tasks import TaskStatusFilter, list_tasks +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: """Default call returns active tasks with correct shape.""" diff --git a/tests/routers/openml/task_type_get_test.py b/tests/routers/openml/task_type_get_test.py index 61bd0c91..b186da18 100644 --- a/tests/routers/openml/task_type_get_test.py +++ b/tests/routers/openml/task_type_get_test.py @@ -1,12 +1,15 @@ import asyncio from http import HTTPStatus +from typing import TYPE_CHECKING import deepdiff.diff -import httpx import pytest from core.errors import TaskTypeNotFoundError +if TYPE_CHECKING: + import httpx + @pytest.mark.parametrize( "ttype_id", diff --git a/tests/routers/openml/task_type_list_test.py b/tests/routers/openml/task_type_list_test.py index 871def3b..ee61e1b0 100644 --- a/tests/routers/openml/task_type_list_test.py +++ b/tests/routers/openml/task_type_list_test.py @@ -1,6 +1,8 @@ import asyncio +from typing import TYPE_CHECKING -import httpx +if TYPE_CHECKING: + import httpx async def test_list_task_type(py_api: httpx.AsyncClient, php_api: httpx.AsyncClient) -> None: