diff --git a/src/database/datasets.py b/src/database/datasets.py index 664e7bdb..5f9c0e5c 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,7 +2,7 @@ import datetime from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from sqlalchemy import text from sqlalchemy.exc import IntegrityError @@ -13,6 +13,7 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from routers.types import Identifier, TagString from schemas.datasets.openml import DatasetStatus, Feature if TYPE_CHECKING: @@ -20,7 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection -async def get(id_: int, connection: AsyncConnection) -> Row | None: +async def get(id_: Identifier, connection: AsyncConnection) -> Row | None: row = await connection.execute( text( """ @@ -34,7 +35,7 @@ async def get(id_: int, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: +async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None: row = await connection.execute( text( """ @@ -48,7 +49,11 @@ async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row | None: +async def get_tag( + dataset_id: Identifier, + tag: TagString, + connection: AsyncConnection, +) -> Row | None: return ( await connection.execute( text( @@ -63,7 +68,7 @@ async def get_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> Row ).first() -async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> None: +async def delete_tag(dataset_id: Identifier, tag: TagString, connection: AsyncConnection) -> None: await connection.execute( text( """ @@ -75,7 +80,7 @@ async def delete_tag(dataset_id: int, tag: str, connection: AsyncConnection) -> ) -async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]: +async def get_tags_for(id_: Identifier, connection: AsyncConnection) -> list[str]: row = await connection.execute( text( """ @@ -115,7 +120,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) async def get_description( - id_: int, + id_: Identifier, connection: AsyncConnection, ) -> Row | None: """Get the most recent description for the dataset.""" @@ -133,7 +138,7 @@ async def get_description( return row.first() -async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus: +async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetStatus: """Get most recent status for the dataset.""" row = ( await connection.execute( @@ -152,7 +157,10 @@ async def get_status(id_: int, connection: AsyncConnection) -> DatasetStatus: return DatasetStatus(row.status) if row else DatasetStatus.IN_PREPARATION -async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None: +async def get_latest_processing_update( + dataset_id: Identifier, + connection: AsyncConnection, +) -> Row | None: row = await connection.execute( text( """ @@ -167,7 +175,7 @@ async def get_latest_processing_update(dataset_id: int, connection: AsyncConnect return row.first() -async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]: +async def get_features(dataset_id: Identifier, connection: AsyncConnection) -> list[Feature]: row = await connection.execute( text( """ @@ -184,7 +192,7 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea async def get_feature_ontologies( - dataset_id: int, + dataset_id: Identifier, connection: AsyncConnection, ) -> dict[int, list[str]]: rows = await connection.execute( @@ -204,7 +212,7 @@ async def get_feature_ontologies( async def get_feature_values( - dataset_id: int, + dataset_id: Identifier, *, feature_index: int, connection: AsyncConnection, @@ -224,10 +232,10 @@ async def get_feature_values( async def update_status( - dataset_id: int, - status: str, + dataset_id: Identifier, + status: Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], *, - user_id: int, + user_id: Identifier, connection: AsyncConnection, ) -> None: await connection.execute( @@ -246,7 +254,7 @@ async def update_status( ) -async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None: +async def remove_deactivated_status(dataset_id: Identifier, connection: AsyncConnection) -> None: await connection.execute( text( """ diff --git a/src/database/flows.py b/src/database/flows.py index ed022c40..504e7821 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -3,11 +3,13 @@ from sqlalchemy import Row, text +from routers.types import Identifier + if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]: rows = await expdb.execute( text( """ @@ -24,7 +26,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: ) -async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]: +async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]: rows = await expdb.execute( text( """ @@ -39,7 +41,7 @@ async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]: return [tag.tag for tag in tag_rows] -async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]: rows = await expdb.execute( text( """ @@ -71,7 +73,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) return row.one_or_none() -async def get(id_: int, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( text( """ diff --git a/src/database/qualities.py b/src/database/qualities.py index 9180b137..4899f028 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -4,13 +4,14 @@ from sqlalchemy import text +from routers.types import Identifier from schemas.datasets.openml import Quality if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]: +async def get_for_dataset(dataset_id: Identifier, connection: AsyncConnection) -> list[Quality]: row = await connection.execute( text( """ @@ -26,7 +27,7 @@ async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[ async def get_for_datasets( - dataset_ids: Iterable[int], + dataset_ids: Iterable[Identifier], quality_names: Iterable[str], connection: AsyncConnection, ) -> dict[int, list[Quality]]: diff --git a/src/database/runs.py b/src/database/runs.py index 6eef0b3b..22b0678d 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -5,11 +5,13 @@ from sqlalchemy import Row, text +from routers.types import Identifier + if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def exist(id_: int, expdb: AsyncConnection) -> bool: +async def exist(id_: Identifier, expdb: AsyncConnection) -> bool: """Check if a run exists by ID.""" row = await expdb.execute( text( @@ -24,7 +26,7 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool: return bool(row.one_or_none()) -async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_trace(run_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]: """Get trace rows for a run from the trace table.""" rows = await expdb.execute( text( diff --git a/src/database/setups.py b/src/database/setups.py index c8651d93..1c959b34 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -4,12 +4,14 @@ from sqlalchemy import text +from routers.types import Identifier, TagString + if TYPE_CHECKING: from sqlalchemy.engine import Row, RowMapping from sqlalchemy.ext.asyncio import AsyncConnection -async def get(setup_id: int, connection: AsyncConnection) -> Row | None: +async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None: """Get the setup with id `setup_id` from the database.""" row = await connection.execute( text( @@ -24,7 +26,7 @@ async def get(setup_id: int, connection: AsyncConnection) -> Row | None: return row.first() -async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[RowMapping]: +async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> list[RowMapping]: """Get all parameters for setup with `setup_id` from the database.""" rows = await connection.execute( text( @@ -51,7 +53,7 @@ async def get_parameters(setup_id: int, connection: AsyncConnection) -> list[Row return list(rows.mappings().all()) -async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]: +async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]: """Get all tags for setup with `setup_id` from the database.""" rows = await connection.execute( text( @@ -66,7 +68,7 @@ async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]: return list(rows.all()) -async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None: +async def untag(setup_id: Identifier, tag: TagString, connection: AsyncConnection) -> None: """Remove tag `tag` from setup with id `setup_id`.""" await connection.execute( text( @@ -79,7 +81,12 @@ async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None: ) -async def tag(setup_id: int, tag: str, user_id: int, connection: AsyncConnection) -> None: +async def tag( + setup_id: Identifier, + tag: TagString, + user_id: Identifier, + connection: AsyncConnection, +) -> None: """Add tag `tag` to setup with id `setup_id`.""" await connection.execute( text( diff --git a/src/database/studies.py b/src/database/studies.py index 3c7c8e96..19394526 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -6,13 +6,14 @@ from sqlalchemy import Row, text from database.users import User +from routers.types import Identifier from schemas.study import CreateStudy, StudyType if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None: +async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None: row = await connection.execute( text( """ @@ -114,7 +115,12 @@ async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int: return cast("int", study_id) -async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: +async def attach_task( + task_id: Identifier, + study_id: Identifier, + user: User, + expdb: AsyncConnection, +) -> None: await expdb.execute( text( """ @@ -126,7 +132,13 @@ async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConne ) -async def attach_run(*, run_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: +async def attach_run( + *, + run_id: Identifier, + study_id: Identifier, + user: User, + expdb: AsyncConnection, +) -> None: await expdb.execute( text( """ @@ -171,8 +183,8 @@ async def attach_tasks( async def attach_runs( - study_id: int, - run_ids: list[int], + study_id: Identifier, + run_ids: list[Identifier], user: User, connection: AsyncConnection, ) -> None: diff --git a/src/database/tasks.py b/src/database/tasks.py index ba4eeccb..39fadb01 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -3,11 +3,13 @@ from sqlalchemy import Row, text +from routers.types import Identifier + if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get(id_: int, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( text( """ @@ -36,7 +38,7 @@ async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: ) -async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None: +async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( text( """ @@ -67,7 +69,7 @@ async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> ) -async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[Row]: rows = await expdb.execute( text( """ @@ -85,7 +87,7 @@ async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]: async def get_task_type_inout_with_template( - task_type: int, + task_type: Identifier, expdb: AsyncConnection, ) -> Sequence[Row]: rows = await expdb.execute( @@ -104,7 +106,7 @@ async def get_task_type_inout_with_template( ) -async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: +async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]: rows = await expdb.execute( text( """ diff --git a/src/database/users.py b/src/database/users.py index bc2d645b..beeb14e6 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -8,6 +8,7 @@ from sqlalchemy import text from config import get_config +from routers.types import Identifier if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection @@ -57,7 +58,11 @@ async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> in return user.id if user else None -async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[int]: +async def get_user_groups_for( + *, + user_id: Identifier, + connection: AsyncConnection, +) -> list[UserGroup]: row = await connection.execute( text( """ @@ -69,12 +74,12 @@ async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> l parameters={"user_id": user_id}, ) rows = row.all() - return [group for (group,) in rows] + return [UserGroup(group) for (group,) in rows] @dataclasses.dataclass class User: - user_id: int + user_id: Identifier _database: AsyncConnection _groups: list[UserGroup] | None = None @@ -86,8 +91,10 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: async def get_groups(self) -> list[UserGroup]: if self._groups is None: - group_ids = await get_user_groups_for(user_id=self.user_id, connection=self._database) - self._groups = [UserGroup(group_id) for group_id in group_ids] + self._groups = await get_user_groups_for( + user_id=self.user_id, + connection=self._database, + ) return self._groups async def is_admin(self) -> bool: diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 045f4e50..270c71b5 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -46,7 +46,7 @@ CasualString128, Identifier, IntegerRange, - SystemString64, + TagString, integer_range_regex, ) from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType @@ -63,7 +63,7 @@ ) async def tag_dataset( data_id: Annotated[Identifier, Body()], - tag: Annotated[SystemString64, Body()], + tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, Any]]: @@ -87,13 +87,13 @@ async def tag_dataset( class TagInfo(TypedDict): id: str - tag: NotRequired[SystemString64 | list[SystemString64]] + tag: NotRequired[TagString | list[TagString]] @router.post(path="/untag", deprecated=True) async def untag_dataset_like_php( data_id: Annotated[Identifier, Body()], - tag: Annotated[SystemString64, Body()], + tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["data_untag"], TagInfo]: @@ -110,7 +110,7 @@ async def untag_dataset_like_php( @router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT) async def untag_dataset( identifier: Identifier, - tag: Annotated[SystemString64, Query()], + tag: Annotated[TagString, Query()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> None: @@ -158,7 +158,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: async def list_datasets( # noqa: PLR0913, C901 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[CasualString128 | None, Body()] = None, - tag: Annotated[SystemString64 | None, Body()] = None, + tag: Annotated[TagString | None, Body()] = None, data_version: Annotated[ Identifier | None, Body(description="The dataset version to include in the search."), diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py index dbca9038..ef71ce61 100644 --- a/src/routers/openml/setups.py +++ b/src/routers/openml/setups.py @@ -15,7 +15,7 @@ ) from database.users import User from routers.dependencies import expdb_connection, fetch_user_or_raise -from routers.types import Identifier, SystemString64 +from routers.types import Identifier, TagString from schemas.setups import SetupParameters, SetupResponse if TYPE_CHECKING: @@ -49,7 +49,7 @@ async def get_setup( @router.post(path="/tag") async def tag_setup( setup_id: Annotated[Identifier, Body()], - tag: Annotated[SystemString64, Body()], + tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: @@ -76,7 +76,7 @@ async def tag_setup( @router.post(path="/untag") async def untag_setup( setup_id: Annotated[Identifier, Body()], - tag: Annotated[SystemString64, Body()], + tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 411ea5b5..8faa898c 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -17,7 +17,7 @@ CasualString128, Identifier, IntegerRange, - SystemString64, + TagString, integer_range_regex, ) from schemas.datasets.openml import Task @@ -231,8 +231,8 @@ def _quality_clause(quality: str, range_: str | None) -> str: async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 pagination: Annotated[Pagination, Body(default_factory=Pagination)], task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None, - tag: Annotated[SystemString64 | None, Body()] = None, - data_tag: Annotated[SystemString64 | None, Body()] = None, + tag: Annotated[TagString | None, Body()] = None, + data_tag: Annotated[TagString | None, Body()] = None, status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, task_id: Annotated[ list[Identifier] | None, diff --git a/src/routers/types.py b/src/routers/types.py index fcdb876a..fe945778 100644 --- a/src/routers/types.py +++ b/src/routers/types.py @@ -2,8 +2,13 @@ from pydantic import Field -SystemString64 = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)] -CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)] +# Known as SystemString64 in the XSD +TagString = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)] + +# Currently used for a variety of fields, like `name` or `feature`. +CasualString = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1)] +CasualString128 = Annotated[CasualString, Field(max_length=128)] + Identifier = Annotated[int, Field(gt=0)] integer_range_regex = r"^(\d+)(\.\.\d+)?$" diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 5bcfb4a3..bb12a74b 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -31,25 +31,6 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async assert response.status_code == HTTPStatus.UNAUTHORIZED -@pytest.mark.parametrize( - "tag", - ["", "h@", " a", "a" * 65], - ids=["too short", "@", "space", "too long"], -) -async def test_dataset_tag_invalid_tag_is_rejected( - # Constraints for the tag are handled by FastAPI - tag: str, - py_api: httpx.AsyncClient, -) -> None: - response = await py_api.post( - f"/datasets/tag?api_key={ApiKey.ADMIN}", - json={"data_id": 1, "tag": tag}, - ) - - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - assert response.json()["errors"][0]["loc"] == ["body", "tag"] - - # ── Direct call tests: tag_dataset ── diff --git a/tests/types_test.py b/tests/types_test.py new file mode 100644 index 00000000..7a4df612 --- /dev/null +++ b/tests/types_test.py @@ -0,0 +1,84 @@ +"""Tests validation of custom types. + +Note that for parametrized tests, it is important that the value order and +amount must be consistent to allow distribution with pytest-xdist: +https://pytest-xdist.readthedocs.io/en/latest/known-limitations.html#order-and-amount-of-test-must-be-consistent + +""" + +import string + +import pytest +from pydantic import TypeAdapter, ValidationError + +from routers.types import CasualString, Identifier, TagString + +_identifier = TypeAdapter(Identifier) + + +def test_identifier_accepts_positive_integer() -> None: + assert _identifier.validate_strings("1") == 1 + + +def test_identifier_rejects_non_integer() -> None: + with pytest.raises(ValidationError): + _identifier.validate_strings("foo") + + with pytest.raises(ValidationError): + _identifier.validate_strings("1.2") + + +def test_identifier_rejects_negative() -> None: + with pytest.raises(ValidationError): + _identifier.validate_strings("-1") + + +def test_identifier_rejects_zero() -> None: + with pytest.raises(ValidationError): + _identifier.validate_strings("0") + + +_tag_string = TypeAdapter(TagString) +_valid_punctuation_tag = list("_-.") +_invalid_punctuation_tag = sorted(set(string.punctuation) - set(_valid_punctuation_tag)) + + +def test_tag_string_pattern() -> None: + assert _tag_string.json_schema()["pattern"] == r"^[\w\-\.]+$" + + +@pytest.mark.parametrize("tag", ["a", "c" * 64, "version2.0", "study-14", "study_15"]) +def test_tag_string_accepts_valid(tag: str) -> None: + assert _tag_string.validate_strings(tag) == tag + + +@pytest.mark.parametrize( + "tag", ["", " ", "a ", " a", "a b", "a\t", "c" * 65, *_invalid_punctuation_tag] +) +def test_tag_string_rejects_invalid(tag: str) -> None: + with pytest.raises(ValidationError): + _tag_string.validate_strings(tag) + + +_casual_string = TypeAdapter(CasualString) +_valid_punctuation_casual_string = list(set("_-.(),")) +_invalid_punctuation_casual_string = sorted( + set(string.punctuation) - set(_valid_punctuation_casual_string) +) + + +def test_casual_string_pattern() -> None: + assert _casual_string.json_schema()["pattern"] == r"^[\w\-\.\(\),]+$" + + +@pytest.mark.parametrize("string", ["a", "a" * 1000, "_-.(),"]) +def test_casual_string_accepts_valid(string: str) -> None: + assert _casual_string.validate_strings(string) + + +@pytest.mark.parametrize( + "string", ["", " ", "a ", " a", "a b", "a\t", *_invalid_punctuation_casual_string] +) +def test_casual_string_rejects_invalid(string: str) -> None: + with pytest.raises(ValidationError): + _casual_string.validate_strings(string)