Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
Expand All @@ -13,14 +13,15 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
from routers.types import Identifier, TagString
from schemas.datasets.openml import DatasetStatus, Feature

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: int, connection: AsyncConnection) -> Row | None:
async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
Expand All @@ -34,7 +35,7 @@ async def get(id_: int, connection: AsyncConnection) -> Row | None:
return row.one_or_none()


async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
Expand All @@ -48,7 +49,11 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
return row.one_or_none()


async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None:
async def get_tag(
dataset_id: Identifier,
tag: TagString,
connection: AsyncConnection,
) -> Row | None:
return (
await connection.execute(
text(
Expand All @@ -63,7 +68,7 @@ async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row
).first()


async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None:
async def delete_tag(dataset_id: Identifier, tag: TagString, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
Expand All @@ -75,7 +80,7 @@ async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) ->
)


async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
async def get_tags_for(id_: Identifier, connection: AsyncConnection) -> list[str]:
row = await connection.execute(
text(
"""
Expand Down Expand Up @@ -115,7 +120,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)


async def get_description(
id_: int,
id_: Identifier,
connection: AsyncConnection,
) -> Row | None:
"""Get the most recent description for the dataset."""
Expand All @@ -133,7 +138,7 @@ async def get_description(
return row.first()


async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus:
async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetStatus:
"""Get most recent status for the dataset."""
row = (
await connection.execute(
Expand All @@ -152,7 +157,10 @@ async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus:
return DatasetStatus(row.status) if row else DatasetStatus.IN_PREPARATION


async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None:
async def get_latest_processing_update(
dataset_id: Identifier,
connection: AsyncConnection,
) -> Row | None:
row = await connection.execute(
text(
"""
Expand All @@ -167,7 +175,7 @@ async def get_latest_processing_update(dataset_id: int, connection: AsyncConnect
return row.first()


async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]:
async def get_features(dataset_id: Identifier, connection: AsyncConnection) -> list[Feature]:
row = await connection.execute(
text(
"""
Expand All @@ -184,7 +192,7 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea


async def get_feature_ontologies(
dataset_id: int,
dataset_id: Identifier,
connection: AsyncConnection,
) -> dict[int, list[str]]:
rows = await connection.execute(
Expand All @@ -204,7 +212,7 @@ async def get_feature_ontologies(


async def get_feature_values(
dataset_id: int,
dataset_id: Identifier,
*,
feature_index: int,
connection: AsyncConnection,
Expand All @@ -224,10 +232,10 @@ async def get_feature_values(


async def update_status(
dataset_id: int,
status: str,
dataset_id: Identifier,
status: Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED],
*,
user_id: int,
user_id: Identifier,
connection: AsyncConnection,
) -> None:
await connection.execute(
Expand All @@ -246,7 +254,7 @@ async def update_status(
)


async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None:
async def remove_deactivated_status(dataset_id: Identifier, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
Expand Down
10 changes: 6 additions & 4 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from sqlalchemy import Row, text

from routers.types import Identifier

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
Expand All @@ -24,7 +26,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
)


async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
Expand All @@ -39,7 +41,7 @@ async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
return [tag.tag for tag in tag_rows]


async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
Expand Down Expand Up @@ -71,7 +73,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection)
return row.one_or_none()


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
Expand Down
5 changes: 3 additions & 2 deletions src/database/qualities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from sqlalchemy import text

from routers.types import Identifier
from schemas.datasets.openml import Quality

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]:
async def get_for_dataset(dataset_id: Identifier, connection: AsyncConnection) -> list[Quality]:
row = await connection.execute(
text(
"""
Expand All @@ -26,7 +27,7 @@ async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[


async def get_for_datasets(
dataset_ids: Iterable[int],
dataset_ids: Iterable[Identifier],
quality_names: Iterable[str],
connection: AsyncConnection,
) -> dict[int, list[Quality]]:
Expand Down
6 changes: 4 additions & 2 deletions src/database/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from sqlalchemy import Row, text

from routers.types import Identifier

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def exist(id_: int, expdb: AsyncConnection) -> bool:
async def exist(id_: Identifier, expdb: AsyncConnection) -> bool:
"""Check if a run exists by ID."""
row = await expdb.execute(
text(
Expand All @@ -24,7 +26,7 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())


async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_trace(run_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
Expand Down
17 changes: 12 additions & 5 deletions src/database/setups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from sqlalchemy import text

from routers.types import Identifier, TagString

if TYPE_CHECKING:
from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None:
"""Get the setup with id `setup_id` from the database."""
row = await connection.execute(
text(
Expand All @@ -24,7 +26,7 @@ async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
return row.first()


async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[RowMapping]:
async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> list[RowMapping]:
"""Get all parameters for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
Expand All @@ -51,7 +53,7 @@ async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[Row
return list(rows.mappings().all())


async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]:
"""Get all tags for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
Expand All @@ -66,7 +68,7 @@ async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
return list(rows.all())


async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
async def untag(setup_id: Identifier, tag: TagString, connection: AsyncConnection) -> None:
"""Remove tag `tag` from setup with id `setup_id`."""
await connection.execute(
text(
Expand All @@ -79,7 +81,12 @@ async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
)


async def tag(setup_id: int, tag: str, user_id: int, connection: AsyncConnection) -> None:
async def tag(
setup_id: Identifier,
tag: TagString,
user_id: Identifier,
connection: AsyncConnection,
) -> None:
"""Add tag `tag` to setup with id `setup_id`."""
await connection.execute(
text(
Expand Down
22 changes: 17 additions & 5 deletions src/database/studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from sqlalchemy import Row, text

from database.users import User
from routers.types import Identifier
from schemas.study import CreateStudy, StudyType

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None:
async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
Expand Down Expand Up @@ -114,7 +115,12 @@ async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int:
return cast("int", study_id)


async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None:
async def attach_task(
task_id: Identifier,
study_id: Identifier,
user: User,
expdb: AsyncConnection,
) -> None:
await expdb.execute(
text(
"""
Expand All @@ -126,7 +132,13 @@ async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConne
)


async def attach_run(*, run_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None:
async def attach_run(
*,
run_id: Identifier,
study_id: Identifier,
user: User,
expdb: AsyncConnection,
) -> None:
await expdb.execute(
text(
"""
Expand Down Expand Up @@ -171,8 +183,8 @@ async def attach_tasks(


async def attach_runs(
study_id: int,
run_ids: list[int],
study_id: Identifier,
run_ids: list[Identifier],
user: User,
connection: AsyncConnection,
) -> None:
Expand Down
Loading
Loading