Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = [
]
description = "The Python-based REST API for OpenML."
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.14"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand Down
7 changes: 4 additions & 3 deletions src/core/access.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any

from sqlalchemy.engine import Row
from typing import TYPE_CHECKING, Any

from database.users import User
from schemas.datasets.openml import Visibility

if TYPE_CHECKING:
from sqlalchemy.engine import Row


async def _user_has_access(
dataset: Row[Any],
Expand Down
7 changes: 5 additions & 2 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
"""

from http import HTTPStatus
from typing import TYPE_CHECKING

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

if TYPE_CHECKING:
from fastapi import Request
from fastapi.exceptions import RequestValidationError

# =============================================================================
# Base Exception
# =============================================================================
Expand Down
6 changes: 4 additions & 2 deletions src/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import html

from sqlalchemy.engine import Row
from typing import TYPE_CHECKING

from config import load_routing_configuration
from schemas.datasets.openml import DatasetFileFormat

if TYPE_CHECKING:
from sqlalchemy.engine import Row


def _str_to_bool(string: str) -> bool:
if string.casefold() in ["true", "1", "yes", "y"]:
Expand Down
7 changes: 5 additions & 2 deletions src/core/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import TYPE_CHECKING

from loguru import logger
from starlette.requests import Request
from starlette.responses import Response

from config import load_configuration

if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response


def setup_log_sinks(configuration_file: Path | None = None) -> None:
"""Configure loguru based on app configuration."""
Expand Down
34 changes: 32 additions & 2 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import datetime
from collections import defaultdict
from typing import TYPE_CHECKING

from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection

from database.exceptions import (
_DUPLICATE_ENTRY,
Expand All @@ -16,6 +15,10 @@
)
from schemas.datasets.openml import Feature

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
Expand Down Expand Up @@ -45,6 +48,33 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
return row.one_or_none()


async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None:
return (
await connection.execute(
text(
"""
SELECT *
FROM dataset_tag
WHERE id = :dataset_id AND tag = :tag
""",
),
parameters={"dataset_id": dataset_id, "tag": tag},
)
).first()


async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
DELETE FROM dataset_tag
WHERE id = :dataset_id AND tag = :tag
""",
),
parameters={"dataset_id": dataset_id, "tag": tag},
)


async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
row = await connection.execute(
text(
Expand Down
6 changes: 4 additions & 2 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import cast
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
rows = await connection.execute(
Expand Down
6 changes: 4 additions & 2 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import cast
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
Expand Down
5 changes: 4 additions & 1 deletion src/database/qualities.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection

from schemas.datasets.openml import Quality

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]:
row = await connection.execute(
Expand Down
6 changes: 4 additions & 2 deletions src/database/runs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Database queries for run-related data."""

from collections.abc import Sequence
from typing import cast
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def exist(id_: int, expdb: AsyncConnection) -> bool:
Expand Down
8 changes: 6 additions & 2 deletions src/database/setups.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""All database operations that directly operate on setups."""

from typing import TYPE_CHECKING

from sqlalchemy import text
from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection

if TYPE_CHECKING:
from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
Expand Down
6 changes: 4 additions & 2 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import re
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import cast
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.users import User
from schemas.study import CreateStudy, StudyType

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
Expand Down
6 changes: 4 additions & 2 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import cast
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
Expand Down
6 changes: 4 additions & 2 deletions src/database/users.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import dataclasses
from enum import IntEnum
from typing import Annotated, Self
from typing import TYPE_CHECKING, Annotated, Self

from pydantic import StringConstraints
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection

from config import load_configuration

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection

# If `allow_test_api_keys` is set, the key may also be one of `normaluser`,
# `normaluser2`, or `abc` (admin).
api_key_pattern = r"^[0-9a-fA-F]{32}$"
Expand Down
6 changes: 4 additions & 2 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends
from loguru import logger
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import AuthenticationFailedError, AuthenticationRequiredError
from database.setup import expdb_database, user_database
from database.users import APIKey, User

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def expdb_connection() -> AsyncIterator[AsyncConnection]:
engine = expdb_database()
Expand Down
6 changes: 4 additions & 2 deletions src/routers/mldcat_ap/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
"""

import asyncio
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncConnection

import config
from core.errors import ServiceNotFoundError
Expand All @@ -28,6 +27,9 @@
Quality,
)

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection

router = APIRouter(prefix="/mldcat_ap", tags=["MLDCAT-AP"])
_configuration = config.load_configuration()
_server_url = (
Expand Down
Loading
Loading