diff --git a/src/cattrs/strategies/__init__.py b/src/cattrs/strategies/__init__.py index 9caf0732..40598a42 100644 --- a/src/cattrs/strategies/__init__.py +++ b/src/cattrs/strategies/__init__.py @@ -2,11 +2,12 @@ from ._class_methods import use_class_methods from ._subclasses import include_subclasses -from ._unions import configure_tagged_union, configure_union_passthrough +from ._unions import configure_tagged_union, configure_union_passthrough, configure_union_single_collection_dispatch __all__ = [ "configure_tagged_union", "configure_union_passthrough", + "configure_union_single_collection_dispatch", "include_subclasses", "use_class_methods", ] diff --git a/src/cattrs/strategies/_unions.py b/src/cattrs/strategies/_unions.py index 14add983..63b619e0 100644 --- a/src/cattrs/strategies/_unions.py +++ b/src/cattrs/strategies/_unions.py @@ -1,5 +1,6 @@ -from collections import defaultdict -from typing import Any, Callable, Union +import collections.abc +from collections import defaultdict, deque +from typing import Any, Callable, Union, get_origin from attrs import NOTHING, NothingType @@ -282,3 +283,74 @@ def contains_native_union(exact_type: Any) -> bool: converter.register_structure_hook_factory( contains_native_union, make_structure_native_union ) + + +# Design choice: it was easy to extend the logic to deque and set types +# but are they worth adding? +_COLLECTION_TYPES = frozenset([ + collections.abc.MutableSequence, + collections.abc.MutableSet, + collections.abc.Sequence, + collections.abc.Set, + deque, + frozenset, + list, + set, + tuple, +]) + +def configure_union_single_collection_dispatch(converter: BaseConverter): + def is_union_single_collection(exact_type: Any) -> bool: + # TODO: Handle TypeAliasType (see #742) + + if not is_union_type(exact_type): + return False + + type_args = set(exact_type.__args__) + if len(type_args) == 2 and type(None) in type_args: + # As in union_passthrough, we do not want to handle optionals + return False + + # Design choice: only support the case where one of _COLLECTION_TYPES + # appears in the Union + collection_type_args = [ + t + for t in type_args + if t in _COLLECTION_TYPES or get_origin(t) in _COLLECTION_TYPES + ] + return len(collection_type_args) == 1 + + def make_structure_union_single_collection( + exact_type: Any, / + ) -> Callable[[Any, Any], Any]: + # TODO: Handle TypeAliasType (see #742) + + type_args = set(exact_type.__args__) + collection_type_arg = next( + t + for t in type_args + if t in _COLLECTION_TYPES or get_origin(t) in _COLLECTION_TYPES + ) + + other_type_args = [t for t in type_args if t != collection_type_arg] + spillover_type: Any = ( + Union[tuple(other_type_args)] + if len(other_type_args) > 1 + else other_type_args[0] + ) + + def structure_union_single_collection( + val: Any, + _: Any, + collection_type=collection_type_arg, + spillover=spillover_type, + ) -> Any: + # Design choice: only detect known concrete types as valid source types + # That avoids having to blacklist e.g. str or bytes + if isinstance(val, (deque, frozenset, list, set, tuple)): + return converter.structure(val, collection_type) + return converter.structure(val, spillover) + + return structure_union_single_collection + + converter.register_structure_hook_factory(is_union_single_collection, make_structure_union_single_collection) diff --git a/tests/strategies/test_union_single_collection_dispatch.py b/tests/strategies/test_union_single_collection_dispatch.py new file mode 100644 index 00000000..6f974dcc --- /dev/null +++ b/tests/strategies/test_union_single_collection_dispatch.py @@ -0,0 +1,85 @@ + +from collections import deque +from collections.abc import Callable, Collection, Iterable, MutableSequence, MutableSet, Sequence, Set +from typing import Any, Union + +import pytest + +from attrs import define +from cattrs.converters import BaseConverter +from cattrs.strategies import configure_union_single_collection_dispatch + + +@define +class CollectionParameter: + type_factory: Callable[[Any], type[Collection]] + factory: Callable[[Iterable], Collection] + + +@pytest.fixture( + params=[ + pytest.param(CollectionParameter(lambda t: deque[t], deque), id="deque"), + pytest.param(CollectionParameter(lambda t: frozenset[t], frozenset), id="frozenset"), + pytest.param(CollectionParameter(lambda t: list[t], list), id="list"), + pytest.param(CollectionParameter(lambda t: MutableSequence[t], list), id="MutableSequence"), + pytest.param(CollectionParameter(lambda t: MutableSet[t], set), id="MutableSet"), + pytest.param(CollectionParameter(lambda t: Sequence[t], tuple), id="Sequence"), + pytest.param(CollectionParameter(lambda t: Set[t], frozenset), id="Set"), + pytest.param(CollectionParameter(lambda t: set[t], set), id="set"), + pytest.param(CollectionParameter(lambda t: tuple[t, ...], tuple), id="tuple"), + ], +) +def collection(request: pytest.FixtureRequest) -> CollectionParameter: + return request.param + + +def test_works_with_simple_union(converter: BaseConverter, collection: CollectionParameter): + configure_union_single_collection_dispatch(converter) + + union = Union[collection.type_factory(str) | str] + + assert converter.structure("abcd", union) == "abcd" + assert converter.structure("abcd", str) == "abcd" + + + expected_structured = collection.factory(["abcd"]) + assert converter.structure(["abcd"], union) == expected_structured + assert converter.structure(["abcd"], collection.type_factory(str)) == expected_structured + assert converter.structure(deque(["abcd"]), union) == expected_structured + assert converter.structure(deque(["abcd"]), collection.type_factory(str)) == expected_structured + assert converter.structure(frozenset(["abcd"]), union) == expected_structured + assert converter.structure(frozenset(["abcd"]), collection.type_factory(str)) == expected_structured + assert converter.structure(set(["abcd"]), union) == expected_structured + assert converter.structure(set(["abcd"]), collection.type_factory(str)) == expected_structured + assert converter.structure(tuple(["abcd"]), union) == expected_structured + assert converter.structure(tuple(["abcd"]), collection.type_factory(str)) == expected_structured + + +def test_apply_union_disambiguation(converter: BaseConverter, collection: CollectionParameter): + configure_union_single_collection_dispatch(converter) + + @define(frozen=True) + class A: + a: int + + @define(frozen=True) + class B: + b: int + + collection_type = collection.type_factory(Union[A, B]) + union = Union[collection_type, A, B] + + assert converter.structure({"a": 1}, union) == A(1) + assert converter.structure({"a": 1}, Union[A, B]) == A(1) + assert converter.structure({"a": 1}, A) == A(1) + assert converter.structure({"b": 2}, union) == B(2) + assert converter.structure({"b": 2}, Union[A, B]) == B(2) + assert converter.structure({"b": 2}, B) == B(2) + + expected_structured = collection.factory([A(1), B(2)]) + assert converter.structure([{"a": 1}, {"b": 2}], union) == expected_structured + assert converter.structure([{"a": 1}, {"b": 2}], collection_type) == expected_structured + assert converter.structure(deque([{"a": 1}, {"b": 2}]), union) == expected_structured + assert converter.structure(deque([{"a": 1}, {"b": 2}]), collection_type) == expected_structured + assert converter.structure(tuple([{"a": 1}, {"b": 2}]), union) == expected_structured + assert converter.structure(tuple([{"a": 1}, {"b": 2}]), collection_type) == expected_structured