From e0f3576e0d078fef4d5931a35e569525f34cfffd Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 09:57:12 +0200
Subject: [PATCH 1/7] Rename SystemString64 to TagString and test it separately
---
src/routers/openml/datasets.py | 12 ++++----
src/routers/openml/setups.py | 6 ++--
src/routers/openml/tasks.py | 6 ++--
src/routers/types.py | 4 ++-
tests/types_test.py | 50 ++++++++++++++++++++++++++++++++++
5 files changed, 65 insertions(+), 13 deletions(-)
create mode 100644 tests/types_test.py
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 045f4e50..270c71b5 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -46,7 +46,7 @@
CasualString128,
Identifier,
IntegerRange,
- SystemString64,
+ TagString,
integer_range_regex,
)
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
@@ -63,7 +63,7 @@
)
async def tag_dataset(
data_id: Annotated[Identifier, Body()],
- tag: Annotated[SystemString64, Body()],
+ tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
@@ -87,13 +87,13 @@ async def tag_dataset(
class TagInfo(TypedDict):
id: str
- tag: NotRequired[SystemString64 | list[SystemString64]]
+ tag: NotRequired[TagString | list[TagString]]
@router.post(path="/untag", deprecated=True)
async def untag_dataset_like_php(
data_id: Annotated[Identifier, Body()],
- tag: Annotated[SystemString64, Body()],
+ tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["data_untag"], TagInfo]:
@@ -110,7 +110,7 @@ async def untag_dataset_like_php(
@router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT)
async def untag_dataset(
identifier: Identifier,
- tag: Annotated[SystemString64, Query()],
+ tag: Annotated[TagString, Query()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> None:
@@ -158,7 +158,7 @@ def _quality_clause(quality: str, range_: str | None) -> str:
async def list_datasets( # noqa: PLR0913, C901
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[CasualString128 | None, Body()] = None,
- tag: Annotated[SystemString64 | None, Body()] = None,
+ tag: Annotated[TagString | None, Body()] = None,
data_version: Annotated[
Identifier | None,
Body(description="The dataset version to include in the search."),
diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py
index dbca9038..ef71ce61 100644
--- a/src/routers/openml/setups.py
+++ b/src/routers/openml/setups.py
@@ -15,7 +15,7 @@
)
from database.users import User
from routers.dependencies import expdb_connection, fetch_user_or_raise
-from routers.types import Identifier, SystemString64
+from routers.types import Identifier, TagString
from schemas.setups import SetupParameters, SetupResponse
if TYPE_CHECKING:
@@ -49,7 +49,7 @@ async def get_setup(
@router.post(path="/tag")
async def tag_setup(
setup_id: Annotated[Identifier, Body()],
- tag: Annotated[SystemString64, Body()],
+ tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
@@ -76,7 +76,7 @@ async def tag_setup(
@router.post(path="/untag")
async def untag_setup(
setup_id: Annotated[Identifier, Body()],
- tag: Annotated[SystemString64, Body()],
+ tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 411ea5b5..8faa898c 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -17,7 +17,7 @@
CasualString128,
Identifier,
IntegerRange,
- SystemString64,
+ TagString,
integer_range_regex,
)
from schemas.datasets.openml import Task
@@ -231,8 +231,8 @@ def _quality_clause(quality: str, range_: str | None) -> str:
async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None,
- tag: Annotated[SystemString64 | None, Body()] = None,
- data_tag: Annotated[SystemString64 | None, Body()] = None,
+ tag: Annotated[TagString | None, Body()] = None,
+ data_tag: Annotated[TagString | None, Body()] = None,
status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE,
task_id: Annotated[
list[Identifier] | None,
diff --git a/src/routers/types.py b/src/routers/types.py
index fcdb876a..150bbbf4 100644
--- a/src/routers/types.py
+++ b/src/routers/types.py
@@ -2,7 +2,9 @@
from pydantic import Field
-SystemString64 = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)]
+# Known as SystemString64 in the XSD
+TagString = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)]
+
CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)]
Identifier = Annotated[int, Field(gt=0)]
diff --git a/tests/types_test.py b/tests/types_test.py
new file mode 100644
index 00000000..5cf1efe1
--- /dev/null
+++ b/tests/types_test.py
@@ -0,0 +1,50 @@
+import string
+
+import pytest
+from pydantic import TypeAdapter, ValidationError
+
+from routers.types import Identifier, TagString
+
+_identifier = TypeAdapter(Identifier)
+
+
+def test_identifier_accepts_positive_integer() -> None:
+ assert _identifier.validate_strings("1") == 1
+
+
+def test_identifier_rejects_non_integer() -> None:
+ with pytest.raises(ValidationError):
+ _identifier.validate_strings("foo")
+
+ with pytest.raises(ValidationError):
+ _identifier.validate_strings("1.2")
+
+
+def test_identifier_rejects_negative() -> None:
+ with pytest.raises(ValidationError):
+ _identifier.validate_strings("0")
+
+
+def test_identifier_rejects_zero() -> None:
+ with pytest.raises(ValidationError):
+ _identifier.validate_strings("0")
+
+
+_tag_string = TypeAdapter(TagString)
+_valid_punctuation_tag = {"-", ".", "_"}
+_invalid_punctuation_tag = set(string.punctuation) - _valid_punctuation_tag
+
+
+def test_tag_string_pattern() -> None:
+ assert _tag_string.json_schema()["pattern"] == r"^[\w\-\.]+$"
+
+
+@pytest.mark.parametrize("tag", ["a", "c" * 64, "version2.0", "study-14", "study_15"])
+def test_tag_string_accepts_valid(tag: str) -> None:
+ assert _tag_string.validate_strings(tag) == tag
+
+
+@pytest.mark.parametrize("tag", ["", "c" * 65, *_invalid_punctuation_tag])
+def test_tag_string_rejects_invalid(tag: str) -> None:
+ with pytest.raises(ValidationError):
+ _tag_string.validate_strings(tag)
From 6bbf1f2e90eedfe6fa55f2d0d714cdad39ec97dc Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 10:03:28 +0200
Subject: [PATCH 2/7] Test is now covered by dedicated test for TagString
---
tests/routers/openml/dataset_tag_test.py | 19 -------------------
1 file changed, 19 deletions(-)
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 5bcfb4a3..bb12a74b 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -31,25 +31,6 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async
assert response.status_code == HTTPStatus.UNAUTHORIZED
-@pytest.mark.parametrize(
- "tag",
- ["", "h@", " a", "a" * 65],
- ids=["too short", "@", "space", "too long"],
-)
-async def test_dataset_tag_invalid_tag_is_rejected(
- # Constraints for the tag are handled by FastAPI
- tag: str,
- py_api: httpx.AsyncClient,
-) -> None:
- response = await py_api.post(
- f"/datasets/tag?api_key={ApiKey.ADMIN}",
- json={"data_id": 1, "tag": tag},
- )
-
- assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
- assert response.json()["errors"][0]["loc"] == ["body", "tag"]
-
-
# ── Direct call tests: tag_dataset ──
From 688d8e21cc13aed71a3ffc7ababa8b01545d142f Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 10:11:15 +0200
Subject: [PATCH 3/7] Constrain types
---
src/database/datasets.py | 40 ++++++++++++++++++++++++----------------
1 file changed, 24 insertions(+), 16 deletions(-)
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 664e7bdb..5f9c0e5c 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -2,7 +2,7 @@
import datetime
from collections import defaultdict
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
@@ -13,6 +13,7 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
+from routers.types import Identifier, TagString
from schemas.datasets.openml import DatasetStatus, Feature
if TYPE_CHECKING:
@@ -20,7 +21,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(id_: int, connection: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
@@ -34,7 +35,7 @@ async def get(id_: int, connection: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
+async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
@@ -48,7 +49,11 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None:
+async def get_tag(
+ dataset_id: Identifier,
+ tag: TagString,
+ connection: AsyncConnection,
+) -> Row | None:
return (
await connection.execute(
text(
@@ -63,7 +68,7 @@ async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row
).first()
-async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None:
+async def delete_tag(dataset_id: Identifier, tag: TagString, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
@@ -75,7 +80,7 @@ async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) ->
)
-async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
+async def get_tags_for(id_: Identifier, connection: AsyncConnection) -> list[str]:
row = await connection.execute(
text(
"""
@@ -115,7 +120,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
async def get_description(
- id_: int,
+ id_: Identifier,
connection: AsyncConnection,
) -> Row | None:
"""Get the most recent description for the dataset."""
@@ -133,7 +138,7 @@ async def get_description(
return row.first()
-async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus:
+async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetStatus:
"""Get most recent status for the dataset."""
row = (
await connection.execute(
@@ -152,7 +157,10 @@ async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus:
return DatasetStatus(row.status) if row else DatasetStatus.IN_PREPARATION
-async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None:
+async def get_latest_processing_update(
+ dataset_id: Identifier,
+ connection: AsyncConnection,
+) -> Row | None:
row = await connection.execute(
text(
"""
@@ -167,7 +175,7 @@ async def get_latest_processing_update(dataset_id: int, connection: AsyncConnect
return row.first()
-async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]:
+async def get_features(dataset_id: Identifier, connection: AsyncConnection) -> list[Feature]:
row = await connection.execute(
text(
"""
@@ -184,7 +192,7 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea
async def get_feature_ontologies(
- dataset_id: int,
+ dataset_id: Identifier,
connection: AsyncConnection,
) -> dict[int, list[str]]:
rows = await connection.execute(
@@ -204,7 +212,7 @@ async def get_feature_ontologies(
async def get_feature_values(
- dataset_id: int,
+ dataset_id: Identifier,
*,
feature_index: int,
connection: AsyncConnection,
@@ -224,10 +232,10 @@ async def get_feature_values(
async def update_status(
- dataset_id: int,
- status: str,
+ dataset_id: Identifier,
+ status: Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED],
*,
- user_id: int,
+ user_id: Identifier,
connection: AsyncConnection,
) -> None:
await connection.execute(
@@ -246,7 +254,7 @@ async def update_status(
)
-async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None:
+async def remove_deactivated_status(dataset_id: Identifier, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
From a7eb87b73aec317dadfb867b2cf1e032f3211e86 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 10:19:08 +0200
Subject: [PATCH 4/7] Narrow types
---
src/database/flows.py | 10 ++++++----
src/database/qualities.py | 5 +++--
src/database/runs.py | 6 ++++--
src/database/setups.py | 17 ++++++++++++-----
src/database/studies.py | 22 +++++++++++++++++-----
src/database/tasks.py | 12 +++++++-----
src/database/users.py | 17 ++++++++++++-----
7 files changed, 61 insertions(+), 28 deletions(-)
diff --git a/src/database/flows.py b/src/database/flows.py
index ed022c40..504e7821 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -3,11 +3,13 @@
from sqlalchemy import Row, text
+from routers.types import Identifier
+
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
@@ -24,7 +26,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
)
-async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
+async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
@@ -39,7 +41,7 @@ async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
return [tag.tag for tag in tag_rows]
-async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
@@ -71,7 +73,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection)
return row.one_or_none()
-async def get(id_: int, expdb: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
diff --git a/src/database/qualities.py b/src/database/qualities.py
index 9180b137..4899f028 100644
--- a/src/database/qualities.py
+++ b/src/database/qualities.py
@@ -4,13 +4,14 @@
from sqlalchemy import text
+from routers.types import Identifier
from schemas.datasets.openml import Quality
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]:
+async def get_for_dataset(dataset_id: Identifier, connection: AsyncConnection) -> list[Quality]:
row = await connection.execute(
text(
"""
@@ -26,7 +27,7 @@ async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[
async def get_for_datasets(
- dataset_ids: Iterable[int],
+ dataset_ids: Iterable[Identifier],
quality_names: Iterable[str],
connection: AsyncConnection,
) -> dict[int, list[Quality]]:
diff --git a/src/database/runs.py b/src/database/runs.py
index 6eef0b3b..22b0678d 100644
--- a/src/database/runs.py
+++ b/src/database/runs.py
@@ -5,11 +5,13 @@
from sqlalchemy import Row, text
+from routers.types import Identifier
+
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def exist(id_: int, expdb: AsyncConnection) -> bool:
+async def exist(id_: Identifier, expdb: AsyncConnection) -> bool:
"""Check if a run exists by ID."""
row = await expdb.execute(
text(
@@ -24,7 +26,7 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())
-async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_trace(run_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
diff --git a/src/database/setups.py b/src/database/setups.py
index c8651d93..1c959b34 100644
--- a/src/database/setups.py
+++ b/src/database/setups.py
@@ -4,12 +4,14 @@
from sqlalchemy import text
+from routers.types import Identifier, TagString
+
if TYPE_CHECKING:
from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
+async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None:
"""Get the setup with id `setup_id` from the database."""
row = await connection.execute(
text(
@@ -24,7 +26,7 @@ async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
return row.first()
-async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[RowMapping]:
+async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> list[RowMapping]:
"""Get all parameters for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
@@ -51,7 +53,7 @@ async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[Row
return list(rows.mappings().all())
-async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
+async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]:
"""Get all tags for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
@@ -66,7 +68,7 @@ async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
return list(rows.all())
-async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
+async def untag(setup_id: Identifier, tag: TagString, connection: AsyncConnection) -> None:
"""Remove tag `tag` from setup with id `setup_id`."""
await connection.execute(
text(
@@ -79,7 +81,12 @@ async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
)
-async def tag(setup_id: int, tag: str, user_id: int, connection: AsyncConnection) -> None:
+async def tag(
+ setup_id: Identifier,
+ tag: TagString,
+ user_id: Identifier,
+ connection: AsyncConnection,
+) -> None:
"""Add tag `tag` to setup with id `setup_id`."""
await connection.execute(
text(
diff --git a/src/database/studies.py b/src/database/studies.py
index 3c7c8e96..19394526 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -6,13 +6,14 @@
from sqlalchemy import Row, text
from database.users import User
+from routers.types import Identifier
from schemas.study import CreateStudy, StudyType
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None:
+async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
@@ -114,7 +115,12 @@ async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int:
return cast("int", study_id)
-async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None:
+async def attach_task(
+ task_id: Identifier,
+ study_id: Identifier,
+ user: User,
+ expdb: AsyncConnection,
+) -> None:
await expdb.execute(
text(
"""
@@ -126,7 +132,13 @@ async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConne
)
-async def attach_run(*, run_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None:
+async def attach_run(
+ *,
+ run_id: Identifier,
+ study_id: Identifier,
+ user: User,
+ expdb: AsyncConnection,
+) -> None:
await expdb.execute(
text(
"""
@@ -171,8 +183,8 @@ async def attach_tasks(
async def attach_runs(
- study_id: int,
- run_ids: list[int],
+ study_id: Identifier,
+ run_ids: list[Identifier],
user: User,
connection: AsyncConnection,
) -> None:
diff --git a/src/database/tasks.py b/src/database/tasks.py
index ba4eeccb..39fadb01 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -3,11 +3,13 @@
from sqlalchemy import Row, text
+from routers.types import Identifier
+
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(id_: int, expdb: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
@@ -36,7 +38,7 @@ async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]:
)
-async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None:
+async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
@@ -67,7 +69,7 @@ async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) ->
)
-async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
@@ -85,7 +87,7 @@ async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_task_type_inout_with_template(
- task_type: int,
+ task_type: Identifier,
expdb: AsyncConnection,
) -> Sequence[Row]:
rows = await expdb.execute(
@@ -104,7 +106,7 @@ async def get_task_type_inout_with_template(
)
-async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
+async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
diff --git a/src/database/users.py b/src/database/users.py
index bc2d645b..beeb14e6 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -8,6 +8,7 @@
from sqlalchemy import text
from config import get_config
+from routers.types import Identifier
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
@@ -57,7 +58,11 @@ async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> in
return user.id if user else None
-async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[int]:
+async def get_user_groups_for(
+ *,
+ user_id: Identifier,
+ connection: AsyncConnection,
+) -> list[UserGroup]:
row = await connection.execute(
text(
"""
@@ -69,12 +74,12 @@ async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> l
parameters={"user_id": user_id},
)
rows = row.all()
- return [group for (group,) in rows]
+ return [UserGroup(group) for (group,) in rows]
@dataclasses.dataclass
class User:
- user_id: int
+ user_id: Identifier
_database: AsyncConnection
_groups: list[UserGroup] | None = None
@@ -86,8 +91,10 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None:
async def get_groups(self) -> list[UserGroup]:
if self._groups is None:
- group_ids = await get_user_groups_for(user_id=self.user_id, connection=self._database)
- self._groups = [UserGroup(group_id) for group_id in group_ids]
+ self._groups = await get_user_groups_for(
+ user_id=self.user_id,
+ connection=self._database,
+ )
return self._groups
async def is_admin(self) -> bool:
From 53029d46e02d50db0b4834e43bee8b66ec730eab Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 11:23:41 +0200
Subject: [PATCH 5/7] write tests for CasualString
---
src/routers/types.py | 5 ++++-
tests/types_test.py | 26 +++++++++++++++++++++++---
2 files changed, 27 insertions(+), 4 deletions(-)
diff --git a/src/routers/types.py b/src/routers/types.py
index 150bbbf4..fe945778 100644
--- a/src/routers/types.py
+++ b/src/routers/types.py
@@ -5,7 +5,10 @@
# Known as SystemString64 in the XSD
TagString = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)]
-CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)]
+# Currently used for a variety of fields, like `name` or `feature`.
+CasualString = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1)]
+CasualString128 = Annotated[CasualString, Field(max_length=128)]
+
Identifier = Annotated[int, Field(gt=0)]
integer_range_regex = r"^(\d+)(\.\.\d+)?$"
diff --git a/tests/types_test.py b/tests/types_test.py
index 5cf1efe1..e9a046bc 100644
--- a/tests/types_test.py
+++ b/tests/types_test.py
@@ -3,7 +3,7 @@
import pytest
from pydantic import TypeAdapter, ValidationError
-from routers.types import Identifier, TagString
+from routers.types import CasualString, Identifier, TagString
_identifier = TypeAdapter(Identifier)
@@ -31,7 +31,7 @@ def test_identifier_rejects_zero() -> None:
_tag_string = TypeAdapter(TagString)
-_valid_punctuation_tag = {"-", ".", "_"}
+_valid_punctuation_tag = set("_-.")
_invalid_punctuation_tag = set(string.punctuation) - _valid_punctuation_tag
@@ -44,7 +44,27 @@ def test_tag_string_accepts_valid(tag: str) -> None:
assert _tag_string.validate_strings(tag) == tag
-@pytest.mark.parametrize("tag", ["", "c" * 65, *_invalid_punctuation_tag])
+@pytest.mark.parametrize("tag", ["", " ", "c" * 65, *_invalid_punctuation_tag])
def test_tag_string_rejects_invalid(tag: str) -> None:
with pytest.raises(ValidationError):
_tag_string.validate_strings(tag)
+
+
+_casual_string = TypeAdapter(CasualString)
+_valid_punctuation_casual_string = set("_-.(),")
+_invalid_punctuation_casual_string = set(string.punctuation) - _valid_punctuation_casual_string
+
+
+def test_casual_string_pattern() -> None:
+ assert _casual_string.json_schema()["pattern"] == r"^[\w\-\.\(\),]+$"
+
+
+@pytest.mark.parametrize("string", ["a", "a" * 1000, "_-.(),"])
+def test_casual_string_accepts_valid(string: str) -> None:
+ assert _casual_string.validate_strings(string)
+
+
+@pytest.mark.parametrize("string", ["", *_invalid_punctuation_casual_string])
+def test_casual_string_rejects_invalid(string: str) -> None:
+ with pytest.raises(ValidationError):
+ _casual_string.validate_strings(string)
From 7bf8fdcb1f58a4154a97b830d5692934c8b89a15 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 11:41:08 +0200
Subject: [PATCH 6/7] Fixed ordering of parameters for pytest-xdist
compatibility
---
tests/types_test.py | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/tests/types_test.py b/tests/types_test.py
index e9a046bc..8e2dadc7 100644
--- a/tests/types_test.py
+++ b/tests/types_test.py
@@ -1,3 +1,11 @@
+"""Tests validation of custom types.
+
+Note that for parametrized tests, the order and number of parameter values
+must be identical on every worker to allow distribution with pytest-xdist:
+https://pytest-xdist.readthedocs.io/en/latest/known-limitations.html#order-and-amount-of-test-must-be-consistent
+
+"""
+
import string
import pytest
@@ -31,8 +39,8 @@ def test_identifier_rejects_zero() -> None:
_tag_string = TypeAdapter(TagString)
-_valid_punctuation_tag = set("_-.")
-_invalid_punctuation_tag = set(string.punctuation) - _valid_punctuation_tag
+_valid_punctuation_tag = list("_-.")
+_invalid_punctuation_tag = sorted(set(string.punctuation) - set(_valid_punctuation_tag))
def test_tag_string_pattern() -> None:
@@ -51,8 +59,10 @@ def test_tag_string_rejects_invalid(tag: str) -> None:
_casual_string = TypeAdapter(CasualString)
-_valid_punctuation_casual_string = set("_-.(),")
-_invalid_punctuation_casual_string = set(string.punctuation) - _valid_punctuation_casual_string
+_valid_punctuation_casual_string = list("_-.(),")  # list, not set: stable order
+_invalid_punctuation_casual_string = sorted(
+ set(string.punctuation) - set(_valid_punctuation_casual_string)
+)
def test_casual_string_pattern() -> None:
From 6cd637db7fa5a3e9bbaeb5d5bf31bf07619d44af Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 8 May 2026 11:46:51 +0200
Subject: [PATCH 7/7] fix typo and expand test coverage for whitespace in input
---
tests/types_test.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/tests/types_test.py b/tests/types_test.py
index 8e2dadc7..7a4df612 100644
--- a/tests/types_test.py
+++ b/tests/types_test.py
@@ -30,7 +30,7 @@ def test_identifier_rejects_non_integer() -> None:
def test_identifier_rejects_negative() -> None:
with pytest.raises(ValidationError):
- _identifier.validate_strings("0")
+ _identifier.validate_strings("-1")
def test_identifier_rejects_zero() -> None:
@@ -52,7 +52,9 @@ def test_tag_string_accepts_valid(tag: str) -> None:
assert _tag_string.validate_strings(tag) == tag
-@pytest.mark.parametrize("tag", ["", " ", "c" * 65, *_invalid_punctuation_tag])
+@pytest.mark.parametrize(
+ "tag", ["", " ", "a ", " a", "a b", "a\t", "c" * 65, *_invalid_punctuation_tag]
+)
def test_tag_string_rejects_invalid(tag: str) -> None:
with pytest.raises(ValidationError):
_tag_string.validate_strings(tag)
@@ -74,7 +76,9 @@ def test_casual_string_accepts_valid(string: str) -> None:
assert _casual_string.validate_strings(string)
-@pytest.mark.parametrize("string", ["", *_invalid_punctuation_casual_string])
+@pytest.mark.parametrize(
+ "string", ["", " ", "a ", " a", "a b", "a\t", *_invalid_punctuation_casual_string]
+)
def test_casual_string_rejects_invalid(string: str) -> None:
with pytest.raises(ValidationError):
_casual_string.validate_strings(string)