diff --git a/Justfile b/Justfile index 7bab81d5..1c25f065 100644 --- a/Justfile +++ b/Justfile @@ -8,12 +8,14 @@ install: lint: uv run ruff format uv run ruff check --fix - uv run mypy . + uv run mypy . --disable-error-code=unused-ignore + uv run pyrefly check --no-progress-bar lint-ci: uv run ruff format --check uv run ruff check --no-fix - uv run mypy . + uv run mypy . --disable-error-code=unused-ignore + uv run pyrefly check --no-progress-bar test *args: uv run --no-sync pytest {{ args }} diff --git a/docs/migration/v4.md b/docs/migration/v4.md new file mode 100644 index 00000000..65cb9321 --- /dev/null +++ b/docs/migration/v4.md @@ -0,0 +1,92 @@ +# Migrating from 3.\* to 4.\* + +## How to Read This Guide + +This guide is intended to help you migrate existing functionality from `that-depends` version `3.*` to `4.*`. +The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase. + +This migration intentionally focuses only on the **simplest change needed to preserve existing behaviour**. + +If you want to learn more about the new internals introduced in `4.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases). + +--- + +## Changes in the API + +### **Collection providers now return read-only container types** + +In `4.*`, collection providers no longer resolve to mutable built-in containers: + +- `providers.List(...)` now resolves to a `Sequence` implemented as a `tuple` +- `providers.Dict(...)` now resolves to a `Mapping` implemented as a `mappingproxy` + +If your existing code only reads from these values, you likely do not need to change anything. + +If your existing code mutates the resolved value, the **simplest way to preserve the old behaviour** is to convert the result at the call site: + +```python +items = list(MyContainer.items.resolve_sync()) +mapping = dict(MyContainer.mapping.resolve_sync()) +``` + +--- + +## Behaviour-Preserving Migration Examples + +### **`providers.List(...)`** + +Previously in `3.*`, code like this returned a `list`: + +```python +items = MyContainer.items.resolve_sync() +items.append("new-item") +``` + +In `4.*`, `items` is a read-only sequence, so mutating it directly will no longer work. + +To preserve the previous behaviour, change it to: + +```python +items = list(MyContainer.items.resolve_sync()) +items.append("new-item") +``` + +The same applies to async resolution: + +```python +items = list(await MyContainer.items.resolve()) +items.append("new-item") +``` + +--- + +### **`providers.Dict(...)`** + +Previously in `3.*`, code like this returned a mutable `dict`: + +```python +mapping = MyContainer.mapping.resolve_sync() +mapping["extra"] = "value" +``` + +In `4.*`, `mapping` is a read-only mapping, so item assignment will no longer work. + +To preserve the previous behaviour, change it to: + +```python +mapping = dict(MyContainer.mapping.resolve_sync()) +mapping["extra"] = "value" +``` + +The same applies to async resolution: + +```python +mapping = dict(await MyContainer.mapping.resolve()) +mapping["extra"] = "value" +``` + +--- + +## Further Help + +If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues). diff --git a/docs/overrides/main.html b/docs/overrides/main.html index 7496f3a4..1f99571a 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -1,5 +1,5 @@ {% extends "base.html" %} {% block announce %} - that-depends 3.0 just released! If you are upgrading, check out the migration guide. + that-depends 4.0 just released! If you are upgrading, check out the migration guide. {% endblock %} diff --git a/docs/providers/collections.md b/docs/providers/collections.md index 34294ba9..86154416 100644 --- a/docs/providers/collections.md +++ b/docs/providers/collections.md @@ -3,7 +3,7 @@ There are several collection providers: `List` and `Dict` ## List - List provider contains other providers. -- Resolves into list of dependencies. +- Resolves into an immutable sequence of dependencies. ```python import random @@ -16,12 +16,12 @@ class DIContainer(BaseContainer): DIContainer.numbers_sequence.resolve_sync() -# [0.3035656170071561, 0.8280498192037787] +# (0.3035656170071561, 0.8280498192037787) ``` ## Dict - Dict provider is a collection of named providers. -- Resolves into dict of dependencies. +- Resolves into a read-only mapping of dependencies. ```python import random @@ -34,5 +34,5 @@ class DIContainer(BaseContainer): DIContainer.numbers_map.resolve_sync() -# {'key1': 0.6851384528299208, 'key2': 0.41044920948045294} +# mappingproxy({'key1': 0.6851384528299208, 'key2': 0.41044920948045294}) ``` diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index 8aaa327e..87e00c97 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -8,7 +8,7 @@ To interact with both types of contexts, there are two separate interfaces: 1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers. -2. Directly manage a `ContextResource` context by using the `SupportsContext` interface, which both containers +2. Directly manage a `ContextResource` context by using the `SupportsContext` protocol, which both containers and `ContextResource` providers implement. --- @@ -185,7 +185,7 @@ async with container_context(MyContainer.async_resource): ... ``` -It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` interface described +It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` protocol described [here](#quick-reference). ### Context Hierarchy diff --git a/examples/benchmark/RESULTS.md b/examples/benchmark/RESULTS.md index 90ff8dd0..ef211fbc 100644 --- a/examples/benchmark/RESULTS.md +++ b/examples/benchmark/RESULTS.md @@ -3,6 +3,8 @@ Based on this [benchmark](injection.py): | version / iterations | 10^4 | 10^5 | 10^6 | |----------------------|--------|--------|---------| +| 4.0 | 0.0399 | 0.3950 | 3.8107 | +| 3.9.2 | 0.0920 | 0.9804 | 9.1009 | | 3.2.0 | 0.1039 | 1.0563 | 10.3576 | | 3.0.0.a2 | 0.0870 | 0.9430 | 8.8815 | | 3.0.0.a1 | 0.1399 | 1.4136 | 14.1829 | diff --git a/mkdocs.yml b/mkdocs.yml index 67e1bb74..416513c8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,6 +37,7 @@ nav: - Migration: - 1.* to 2.*: migration/v2.md - 2.* to 3.*: migration/v3.md + - 3.* to 4.*: migration/v4.md - Development: - Contributing: dev/contributing.md - Decisions: dev/main-decisions.md diff --git a/pyproject.toml b/pyproject.toml index 2b9359eb..077ab986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dev = [ "mkdocs>=1.6.1", "pytest-randomly", "mkdocs-llmstxt>=0.4.0", - "pydantic<2.12.4" # causes tests to fail because of: TypeError: _eval_type() got an unexpected keyword argument 'prefer_fwd_module' + "pyrefly>=0.61.1", ] [build-system] @@ -62,6 +62,7 @@ module-root = "" python_version = "3.10" strict = true + [tool.ruff] fix = true unsafe-fixes = true diff --git a/tests/experimental/test_container_2.py b/tests/experimental/test_container_2.py index f5a9ec46..30b11561 100644 --- a/tests/experimental/test_container_2.py +++ b/tests/experimental/test_container_2.py @@ -22,6 +22,12 @@ def __hash__(self) -> int: return 0 # pragma: nocover +class _BrokenContextObject(providers.Object[int]): + def get_scope(self) -> ContextScopes | None: + msg = "_missing_scope" + raise AttributeError(msg) + + async def _async_creator() -> AsyncIterator[float]: yield random.random() @@ -30,6 +36,9 @@ def _sync_creator() -> Iterator[_RandomWrapper]: yield _RandomWrapper() +broken_context_provider = _BrokenContextObject(1) + + class Container2(BaseContainer): """Test Container 2.""" @@ -139,6 +148,13 @@ async def test_lazy_provider_not_implemented() -> None: await lazy_provider.tear_down() +def test_lazy_provider_not_implemented_when_context_method_raises_attribute_error() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.broken_context_provider") + + with pytest.raises(NotImplementedError): + lazy_provider.get_scope() + + def test_lazy_provider_attr_getter() -> None: lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") with lazy_provider.context_sync(force=True): diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index b1292c63..c8ea886b 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -171,6 +171,7 @@ class _Container(BaseContainer): attr_getter = _Container.child.v + attr_getter._register_arguments() attr_getter._register_arguments() assert attr_getter.resolve_sync() == _Container.parent.resolve_sync() diff --git a/tests/providers/test_base.py b/tests/providers/test_base.py index 1c8edbb2..7a0601cf 100644 --- a/tests/providers/test_base.py +++ b/tests/providers/test_base.py @@ -6,11 +6,12 @@ from typing_extensions import override from that_depends import BaseContainer -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import AbstractProvider, _resolve_arguments, _resolve_arguments_sync from that_depends.providers.local_singleton import ThreadLocalSingleton from that_depends.providers.mixin import SupportsTeardown from that_depends.providers.resources import Resource from that_depends.providers.singleton import AsyncSingleton, Singleton +from that_depends.utils import is_set class DummyProvider(SupportsTeardown, AbstractProvider[int]): @@ -43,6 +44,18 @@ def resolve_sync(self) -> int: return self._instance # pragma: no cover +async def test_resolve_arguments_with_multiple_values() -> None: + provider = DummyProvider() + + assert await _resolve_arguments((provider, "value"), (True, False)) == [1, "value"] + + +def test_resolve_arguments_sync_with_multiple_values() -> None: + provider = DummyProvider() + + assert _resolve_arguments_sync((provider, "value"), (True, False)) == [1, "value"] + + def test_add_child_provider() -> None: provider_a = DummyProvider() provider_b = DummyProvider() @@ -75,6 +88,28 @@ def test_register_with_mixed_items() -> None: assert parent in child_1._children, "Expected child_1._children to contain parent" +def test_invalidate_scope_init_order_handles_duplicate_descendants() -> None: + root = DummyProvider() + left = DummyProvider() + right = DummyProvider() + shared = DummyProvider() + + root.add_child_provider(left) + root.add_child_provider(right) + left.add_child_provider(shared) + right.add_child_provider(shared) + + for provider in (root, left, right, shared): + provider._scope_context_init_order = () + provider._scope_init_order = () + + root._invalidate_scope_init_order() + + for provider in (root, left, right, shared): + assert provider._scope_context_init_order is None + assert provider._scope_init_order is None + + def test_sync_tear_down_propagation() -> None: parent = DummyProvider() child_1 = DummyProvider() @@ -168,6 +203,16 @@ def test_thread_local_singleton_registration_and_deregistration(dummy_singleton: assert thread_local not in dummy_singleton._children, "ThreadLocalSingleton should be deregistered after teardown." +def test_get_scope_init_order_is_cached_and_includes_parents(dummy_singleton: Singleton[int]) -> None: + singleton = Singleton(lambda value: value + 1, dummy_singleton.cast) + + first = singleton._get_scope_init_order() + second = singleton._get_scope_init_order() + + assert first == (dummy_singleton, singleton) + assert second is first + + def test_resource_registration_and_deregistration(dummy_singleton: Singleton[int]) -> None: resource = Resource(_resource_generator, dummy_singleton.cast) @@ -233,7 +278,7 @@ def test_propagate_off() -> None: parent.tear_down_sync(propagate=False) assert child in parent._children - assert child._instance is not None + assert is_set(child._instance) async def test_async_tear_down_propagation_with_singleton() -> None: @@ -244,7 +289,7 @@ async def test_async_tear_down_propagation_with_singleton() -> None: await parent.tear_down() - assert child._instance is None + assert not is_set(child._instance) async def test_async_propagate_off() -> None: @@ -255,7 +300,7 @@ async def test_async_propagate_off() -> None: await parent.tear_down(propagate=False) - assert child._instance is not None + assert is_set(child._instance) async def test_provider_registration_in_different_scope_async() -> None: diff --git a/tests/providers/test_collections.py b/tests/providers/test_collections.py index 14a58ae3..5d028a35 100644 --- a/tests/providers/test_collections.py +++ b/tests/providers/test_collections.py @@ -28,7 +28,9 @@ async def test_list_provider() -> None: sync_resource = await DIContainer.sync_resource() async_resource = await DIContainer.async_resource() - assert sequence == [sync_resource, async_resource] + assert sequence == (sync_resource, async_resource) + with pytest.raises(TypeError): + typing.cast(typing.Any, sequence)[0] = sync_resource def test_list_failed_sync_resolve() -> None: @@ -48,6 +50,8 @@ async def test_dict_provider() -> None: assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource} assert mapping == DIContainer.mapping.resolve_sync() + with pytest.raises(TypeError): + typing.cast(typing.Any, mapping)["sync_resource"] = sync_resource @pytest.mark.parametrize("provider", [DIContainer.sequence, DIContainer.mapping]) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 9420097f..02a39c8c 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -24,6 +24,7 @@ from that_depends.meta import DefaultScopeNotDefinedError from that_depends.providers import DIContextMiddleware, container_context from that_depends.providers.context_resources import InvalidContextError, _enter_named_scope, fetch_context_item_by_type +from that_depends.utils import is_set logger = logging.getLogger(__name__) @@ -186,14 +187,15 @@ async def test_early_exit_of_container_context() -> None: async def test_resource_context_early_teardown() -> None: context: ResourceContext[str] = ResourceContext(is_async=True) - assert context.context_stack is None + assert context.is_async is True + assert not is_set(context.context_stack) context.tear_down_sync() - assert context.context_stack is None + assert not is_set(context.context_stack) async def test_teardown_sync_container_context_with_async_resource() -> None: resource_context: ResourceContext[typing.Any] = ResourceContext(is_async=True) - resource_context.context_stack = AsyncExitStack() + resource_context.set_context_state(context_stack=AsyncExitStack()) message = "Cannot tear down async context in sync mode" with pytest.raises(RuntimeError, match=message): resource_context.tear_down_sync() @@ -563,6 +565,56 @@ def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: assert _explicit_injected() != _explicit_injected() +def test_sync_context_resource_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope) + + with _enter_named_scope(current_scope), resource.context_sync(): + assert isinstance(resource.resolve_sync(), str) + + +async def test_async_context_resource_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope) + + with _enter_named_scope(current_scope): + async with resource.context_async(): + assert isinstance(await resource.resolve(), str) + + +async def test_async_context_resource_strict_scope_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope, strict_scope=True) + + with _enter_named_scope(current_scope): + async with resource.context_async(force=True): + assert isinstance(await resource.resolve(), str) + + +def test_sync_context_resource_strict_scope_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope, strict_scope=True) + + with _enter_named_scope(current_scope), resource.context_sync(force=True): + assert isinstance(resource.resolve_sync(), str) + + +def test_container_context_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + sync_context_resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope) + + with container_context(_Container, scope=current_scope): + assert isinstance(_Container.sync_context_resource.resolve_sync(), str) + + async def test_async_context_resource_with_dependent_container() -> None: """Container should initialize async context resource for dependent containers.""" async with DIContainer.context_async(): @@ -590,9 +642,18 @@ def test_enter_sync_context_for_async_resource_should_throw( async_context_resource._enter_context_sync() -def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None: - with pytest.raises(RuntimeError): - sync_context_resource._exit_context_sync() +def test_exit_sync_context_before_enter_should_throw() -> None: + provider = providers.ContextResource(create_sync_context_resource) + + with pytest.raises(RuntimeError, match=r"Context is not set, call ``_enter_sync_context`` first"): + provider._exit_context_sync() + + +def test_context_sync_manager_exit_before_enter_should_throw( + sync_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError, match=r"Context is not set, call ``__enter__`` first"): + sync_context_resource.context_sync().__exit__(None, None, None) async def test_exit_async_context_before_enter_should_throw( @@ -609,6 +670,18 @@ def test_enter_sync_context_from_async_resource_should_throw( stack.enter_context(async_context_resource.context_sync()) +def test_enter_and_exit_sync_context_directly(sync_context_resource: providers.ContextResource[str]) -> None: + context = sync_context_resource._enter_context_sync() + + assert isinstance(context, ResourceContext) + assert sync_context_resource.resolve_sync() is not None + + sync_context_resource._exit_context_sync() + + with pytest.raises(RuntimeError, match=r"Context is not set. Use container_context"): + sync_context_resource.resolve_sync() + + async def test_preserve_globals_and_initial_context() -> None: initial_context = {"test_1": "test_1", "test_2": "test_2"} @@ -699,6 +772,11 @@ class _Container(BaseContainer): ... assert _Container.get_scope() is ContextScopes.ANY + class _UnsetScopeContainer(BaseContainer): + default_scope = None + + assert _UnsetScopeContainer.get_scope() is ContextScopes.ANY + class _ScopedContainer(BaseContainer): default_scope = ContextScopes.INJECT diff --git a/tests/providers/test_local_singleton.py b/tests/providers/test_local_singleton.py index a614c958..4f92ddae 100644 --- a/tests/providers/test_local_singleton.py +++ b/tests/providers/test_local_singleton.py @@ -4,10 +4,12 @@ import time import typing from concurrent.futures.thread import ThreadPoolExecutor +from unittest.mock import Mock import pytest from that_depends.providers import AsyncFactory, ThreadLocalSingleton +from that_depends.utils import is_set random.seed(23) @@ -34,7 +36,7 @@ def test_thread_local_singleton_same_thread() -> None: provider.tear_down_sync() - assert provider._instance is None, "Tear down failed: Instance should be reset to None." + assert not is_set(provider._instance), "Tear down failed: Instance should be unset." async def test_async_thread_local_singleton_asyncio() -> None: @@ -48,7 +50,23 @@ async def test_async_thread_local_singleton_asyncio() -> None: await provider.tear_down() - assert provider._instance is None, "Tear down failed: Instance should be reset to None." + assert not is_set(provider._instance), "Tear down failed: Instance should be unset." + + +async def test_thread_local_singleton_reuses_instance_created_while_waiting_on_lock() -> None: + expected_value = 42 + factory = Mock(return_value=1) + provider = ThreadLocalSingleton(factory) + + await provider._asyncio_lock.acquire() + task = asyncio.create_task(provider.resolve()) + await asyncio.sleep(0) + + provider._instance = expected_value + provider._asyncio_lock.release() + + assert await task == expected_value + factory.assert_not_called() def test_thread_local_singleton_different_threads() -> None: diff --git a/tests/providers/test_resources.py b/tests/providers/test_resources.py index a48fc95e..239cd6d3 100644 --- a/tests/providers/test_resources.py +++ b/tests/providers/test_resources.py @@ -14,6 +14,7 @@ create_sync_resource, ) from that_depends import BaseContainer, providers +from that_depends.utils import is_set logger = logging.getLogger(__name__) @@ -189,9 +190,9 @@ def create_resource() -> typing.Iterator[str]: def test_sync_resource_sync_tear_down() -> None: DIContainer.sync_resource.resolve_sync() - assert DIContainer.sync_resource._context.instance is not None + assert is_set(DIContainer.sync_resource._context.instance) DIContainer.sync_resource.tear_down_sync() - assert DIContainer.sync_resource._context.instance is None + assert not is_set(DIContainer.sync_resource._context.instance) async def test_async_resource_sync_tear_down_raises() -> None: diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py index 7bba7bd0..3be5b911 100644 --- a/tests/providers/test_singleton.py +++ b/tests/providers/test_singleton.py @@ -9,6 +9,7 @@ import pytest from that_depends import BaseContainer, providers +from that_depends.utils import is_set @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) @@ -144,6 +145,16 @@ async def test_async_singleton_override() -> None: assert result == SingletonFactory(dep1="bar") +def test_async_singleton_register_arguments_is_idempotent() -> None: + parent = providers.Singleton(lambda: "foo") + singleton_async = providers.AsyncSingleton(create_async_obj, value=parent.cast) + + singleton_async._register_arguments() + singleton_async._register_arguments() + + assert singleton_async in parent._children + + async def test_async_singleton_asyncio_concurrency() -> None: singleton_async = providers.AsyncSingleton(create_async_obj, "foo") @@ -187,9 +198,9 @@ async def test_async_singleton_teardown() -> None: await singleton_async.resolve() singleton_async.tear_down_sync() - assert singleton_async._instance is None + assert not is_set(singleton_async._instance) await singleton_async.resolve() await singleton_async.tear_down() - assert singleton_async._instance is None + assert not is_set(singleton_async._instance) diff --git a/tests/test_container.py b/tests/test_container.py index f7f4752d..fcc5ac09 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -3,6 +3,7 @@ from tests.container import DIContainer from that_depends import BaseContainer, providers +from that_depends.utils import is_set def _sync_resource() -> typing.Iterator[float]: @@ -15,11 +16,11 @@ async def test_container_sync_teardown() -> None: for provider in DIContainer.providers.values(): if isinstance(provider, providers.Resource): if provider._is_async: - assert provider._context.instance is not None + assert is_set(provider._context.instance) else: - assert provider._context.instance is None + assert not is_set(provider._context.instance) if isinstance(provider, providers.Singleton): - assert provider._instance is None + assert not is_set(provider._instance) async def test_container_tear_down() -> None: @@ -27,9 +28,9 @@ async def test_container_tear_down() -> None: await DIContainer.tear_down() for provider in DIContainer.providers.values(): if isinstance(provider, providers.Resource): - assert provider._context.instance is None + assert not is_set(provider._context.instance) if isinstance(provider, providers.Singleton): - assert provider._instance is None + assert not is_set(provider._instance) async def test_container_sync_tear_down_propagation() -> None: @@ -46,8 +47,8 @@ class _DependentContainer(BaseContainer): DIContainer.tear_down_sync() - assert _DependentContainer.singleton._instance is None - assert _DependentContainer.resource._context.instance is None + assert not is_set(_DependentContainer.singleton._instance) + assert not is_set(_DependentContainer.resource._context.instance) async def test_container_tear_down_propagation() -> None: @@ -74,7 +75,7 @@ class _DependentContainer(BaseContainer): await DIContainer.tear_down() - assert _DependentContainer.async_singleton._instance is None - assert _DependentContainer.async_resource._context.instance is None - assert _DependentContainer.sync_singleton._instance is None - assert _DependentContainer.sync_resource._context.instance is None + assert not is_set(_DependentContainer.async_singleton._instance) + assert not is_set(_DependentContainer.async_resource._context.instance) + assert not is_set(_DependentContainer.sync_singleton._instance) + assert not is_set(_DependentContainer.sync_resource._context.instance) diff --git a/tests/test_injection.py b/tests/test_injection.py index 448f4ecf..ac4f3d5b 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -9,8 +9,24 @@ import pytest from tests import container -from that_depends import BaseContainer, ContextScopes, Provide, container_context, get_current_scope, inject, providers -from that_depends.injection import ContextProviderError, StringProviderDefinition +from that_depends import ( + BaseContainer, + ContextScope, + ContextScopes, + Provide, + container_context, + get_current_scope, + inject, + providers, +) +from that_depends.injection import ( + ContextProviderError, + StringProviderDefinition, + _build_injection_plan, + _DirectInjectionParameter, + _SyncInjectionStack, + _TypedInjectionParameter, +) @pytest.fixture(name="fixture_one") @@ -55,6 +71,61 @@ async def inner( await inner(True, arg2=container.SimpleFactory(dep1="1", dep2=2)) +def test_sync_injection_stack_closes_entered_context_managers() -> None: + events: list[str] = [] + + @contextmanager + def _managed() -> typing.Iterator[str]: + events.append("enter") + try: + yield "value" + finally: + events.append("exit") + + with _SyncInjectionStack() as stack: + assert stack.enter_context(_managed()) == "value" + events.append("body") + + assert events == ["enter", "body", "exit"] + + +def test_build_injection_plan_stores_direct_provider_separately() -> None: + provider = providers.Object(1) + + def _injected(value: providers.Object[int] = provider) -> providers.Object[int]: + return value + + plan = _build_injection_plan(_injected) + + assert _injected(provider) is provider + assert plan.direct_parameters == ( + _DirectInjectionParameter(0, "value", provider, provider._get_scope_context_init_order()), + ) + + +def test_build_injection_plan_stores_annotation_for_type_based_injection() -> None: + def _injected(value: float = Provide()) -> float: + return value + + plan = _build_injection_plan(_injected) + + assert _injected(1.0) == 1.0 + assert plan.typed_parameters == (_TypedInjectionParameter(0, "value", float),) + + +def test_build_injection_plan_stores_string_provider_separately() -> None: + def _injected(value: int = Provide["Container.provider"]) -> int: + return value + + plan = _build_injection_plan(_injected) + + assert _injected(1) == 1 + assert len(plan.string_parameters) == 1 + assert plan.string_parameters[0].argument_index == 0 + assert plan.string_parameters[0].field_name == "value" + assert plan.string_parameters[0].definition._definition == "Container.provider" + + async def test_empty_injection() -> None: @inject async def inner(_: int) -> None: @@ -116,6 +187,15 @@ def inner_gen(_: int) -> typing.Generator[int, None, None]: next(inner_gen(1)) +def test_sync_empty_injection_without_scope_warns() -> None: + @inject(scope=None) + def inner(value: int) -> int: + return value + + with pytest.warns(RuntimeWarning, match=r"Expected injection, but nothing found. Remove @inject decorator."): + assert inner(1) == 1 + + def test_type_check() -> None: @inject async def main(simple_factory: container.SimpleFactory = Provide[container.DIContainer.simple_factory]) -> None: @@ -140,6 +220,20 @@ async def _injected(val: int = Provide[_Container.async_resource]) -> int: await inject(scope=ContextScopes.REQUEST)(_injected)() +async def test_async_injection_with_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + injection_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + async_resource = providers.ContextResource(_async_creator).with_config(scope=configured_scope) + + async def _injected(val: int = Provide[_Container.async_resource]) -> int: + return val + + assert await inject(scope=injection_scope)(_injected)() == 1 + + async def test_sync_injection_with_scope() -> None: class _Container(BaseContainer): default_scope = ContextScopes.ANY @@ -156,6 +250,20 @@ def _injected(val: int = Provide[_Container.p_inject]) -> int: inject(scope=ContextScopes.REQUEST)(_injected)() +def test_sync_injection_with_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + injection_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + p_inject = providers.ContextResource(_sync_creator).with_config(scope=configured_scope) + + def _injected(val: int = Provide[_Container.p_inject]) -> int: + return val + + assert inject(scope=injection_scope)(_injected)() == 1 + + def test_inject_decorator_should_not_allow_any_scope() -> None: with pytest.raises(ValueError, match=f"{ContextScopes.ANY} is not allowed in inject decorator."): inject(scope=ContextScopes.ANY) @@ -220,6 +328,34 @@ def _injected(val: int = Provide["_Container.sync_resource"]) -> int: assert _injected() == return_value +async def test_async_injection_with_string_provider_definition_respects_explicit_arguments() -> None: + override_value = 10 + + class _Container(BaseContainer): + async_resource = providers.AsyncFactory(_async_creator) + + @inject + async def _injected(val: int = Provide["_Container.async_resource"]) -> int: + return val + + assert await _injected(override_value) == override_value + assert await _injected(val=override_value) == override_value + + +def test_sync_injection_with_string_provider_definition_respects_explicit_arguments() -> None: + override_value = 10 + + class _Container(BaseContainer): + sync_resource = providers.Factory(lambda: 1) + + @inject + def _injected(val: int = Provide["_Container.sync_resource"]) -> int: + return val + + assert _injected(override_value) == override_value + assert _injected(val=override_value) == override_value + + def test_provider_string_definition_with_alias() -> None: return_value = 321 @@ -336,6 +472,46 @@ def _injected(v: float = Provide[_Container.provider_used]) -> float: assert _Container.provider_used.resolve_sync() +async def test_injection_with_string_provider_definition_initializes_parent_context_async() -> None: + async def _async_resource() -> typing.AsyncIterator[float]: + yield random.random() + + class _Container(BaseContainer): + provider_used = providers.ContextResource(_async_resource).with_config(scope=ContextScopes.INJECT) + dependent = providers.Factory(lambda value: value, provider_used.cast) + + @inject + async def _injected(val: float = Provide["_Container.dependent"]) -> float: + assert val == await _Container.dependent.resolve() + assert val == await _Container.provider_used.resolve() + return val + + assert isinstance(await _injected(), float) + + with pytest.raises(RuntimeError): + await _Container.provider_used.resolve() + + +def test_injection_with_string_provider_definition_initializes_parent_context_sync() -> None: + def _sync_resource() -> typing.Iterator[float]: + yield random.random() + + class _Container(BaseContainer): + provider_used = providers.ContextResource(_sync_resource).with_config(scope=ContextScopes.INJECT) + dependent = providers.Factory(lambda value: value, provider_used.cast) + + @inject + def _injected(val: float = Provide["_Container.dependent"]) -> float: + assert val == _Container.dependent.resolve_sync() + assert val == _Container.provider_used.resolve_sync() + return val + + assert isinstance(_injected(), float) + + with pytest.raises(RuntimeError): + _Container.provider_used.resolve_sync() + + async def test_injection_initializes_context_for_parents_async() -> None: async def _async_resource() -> typing.AsyncIterator[float]: yield random.random() @@ -863,6 +1039,20 @@ def _injected_2(val: float = Provide()) -> float: assert isinstance(_injected_2(), float) +def test_injection_by_type_sync_respects_explicit_arguments() -> None: + class _Container(BaseContainer): + sync_resource = providers.Factory(random.random).bind(float) + + override_value = 10.0 + + @inject(container=_Container) + def _injected(val: float = Provide()) -> float: + return val + + assert _injected(override_value) == override_value + assert _injected(val=override_value) == override_value + + async def test_injection_by_type_async() -> None: class _Container(BaseContainer): sync_resource = providers.Factory(random.random).bind(float) @@ -879,6 +1069,20 @@ async def _injected_2(val: float = Provide()) -> float: assert isinstance(await _injected_2(), float) +async def test_injection_by_type_async_respects_explicit_arguments() -> None: + class _Container(BaseContainer): + sync_resource = providers.Factory(random.random).bind(float) + + override_value = 10.0 + + @inject(container=_Container) + async def _injected(val: float = Provide()) -> float: + return val + + assert await _injected(override_value) == override_value + assert await _injected(val=override_value) == override_value + + async def test_injection_by_type_async_generator() -> None: class _Container(BaseContainer): sync_resource = providers.Factory(random.random).bind(float) diff --git a/tests/test_meta.py b/tests/test_meta.py index b4efc149..6f1a9353 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -9,6 +9,13 @@ class _Test(metaclass=BaseContainerMeta): assert _Test.get_scope() == ContextScopes.ANY +def test_base_container_meta_uses_explicit_default_scope() -> None: + class _Test(metaclass=BaseContainerMeta): + default_scope = ContextScopes.INJECT + + assert _Test.get_scope() == ContextScopes.INJECT + + def test_base_container_meta_has_correct_default_name() -> None: class _Test(metaclass=BaseContainerMeta): pass diff --git a/tests/test_multiple_containers.py b/tests/test_multiple_containers.py index 6e74282a..69a2246f 100644 --- a/tests/test_multiple_containers.py +++ b/tests/test_multiple_containers.py @@ -4,6 +4,7 @@ from tests import container from that_depends import BaseContainer, providers +from that_depends.utils import is_set class InnerContainer(BaseContainer): @@ -23,16 +24,16 @@ async def test_included_container() -> None: assert all(isinstance(x, datetime.datetime) for x in sequence) await OuterContainer.tear_down() - assert InnerContainer.sync_resource._context.instance is None - assert InnerContainer.async_resource._context.instance is None + assert not is_set(InnerContainer.sync_resource._context.instance) + assert not is_set(InnerContainer.async_resource._context.instance) await OuterContainer.init_resources() sync_resource_context = InnerContainer.sync_resource._context assert sync_resource_context - assert sync_resource_context.instance is not None + assert is_set(sync_resource_context.instance) async_resource_context = InnerContainer.async_resource._context assert async_resource_context - assert async_resource_context.instance is not None + assert is_set(async_resource_context.instance) await OuterContainer.tear_down() diff --git a/that_depends/container.py b/that_depends/container.py index f42d08a2..d23afe40 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,13 +1,13 @@ import inspect import typing -from contextlib import contextmanager +from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from typing import overload from typing_extensions import override from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, AsyncSingleton, Resource, Singleton -from that_depends.providers.context_resources import ContextScope, ContextScopes +from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes if typing.TYPE_CHECKING: @@ -26,6 +26,42 @@ class BaseContainer(metaclass=BaseContainerMeta): containers: list[type["BaseContainer"]] default_scope: ContextScope | None = ContextScopes.ANY + @classmethod + def get_scope(cls) -> ContextScope | None: + """Return the default scope used by the container.""" + if cls.default_scope is not None: + return cls.default_scope + return ContextScopes.ANY + + @classmethod + @asynccontextmanager + async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]: + """Enter async contexts for all connected containers and context resources.""" + async with AsyncExitStack() as stack: + for container in cls.get_containers(): + await stack.enter_async_context(container.context_async(force=force)) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource): + await stack.enter_async_context(provider.context_async(force=force)) + yield + + @classmethod + @contextmanager + def context_sync(cls, force: bool = False) -> typing.Iterator[None]: + """Enter sync contexts for all connected containers and sync context resources.""" + with ExitStack() as stack: + for container in cls.get_containers(): + stack.enter_context(container.context_sync(force=force)) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001 + stack.enter_context(provider.context_sync(force=force)) + yield + + @classmethod + def supports_context_sync(cls) -> bool: + """Indicate that container classes support sync context management.""" + return True + @classmethod @overload def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... diff --git a/that_depends/entities/resource_context.py b/that_depends/entities/resource_context.py index 8e10b6e4..725c0a94 100644 --- a/that_depends/entities/resource_context.py +++ b/that_depends/entities/resource_context.py @@ -7,6 +7,7 @@ from typing_extensions import override from that_depends.providers.mixin import CannotTearDownSyncError, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -26,22 +27,41 @@ def __init__(self, is_async: bool) -> None: For example within a ``async with container_context(Container): ...`` statement. """ - self.instance: T_co | None = None + self.instance: T_co | Unset = UNSET self.asyncio_lock: typing.Final = asyncio.Lock() self.threading_lock: typing.Final = threading.Lock() - self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None + self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET self.is_async = is_async + def set_context_state( + self, + *, + instance: T_co | Unset = UNSET, + context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET, + ) -> None: + """Set the context state of the resource. + + Args: + instance: instance. + context_stack: stack. + + Returns: + None + + """ + self.instance = instance + self.context_stack = context_stack + @staticmethod def is_context_stack_async( - context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None, + context_stack: object, ) -> typing.TypeGuard[contextlib.AsyncExitStack]: """Check if the context stack is an async context stack.""" return isinstance(context_stack, contextlib.AsyncExitStack) @staticmethod def is_context_stack_sync( - context_stack: contextlib.AsyncExitStack | contextlib.ExitStack, + context_stack: object, ) -> typing.TypeGuard[contextlib.ExitStack]: """Check if the context stack is a sync context stack.""" return isinstance(context_stack, contextlib.ExitStack) @@ -49,27 +69,27 @@ def is_context_stack_sync( @override async def tear_down(self, propagate: bool = True) -> None: """Tear down the async context stack.""" - if self.context_stack is None: + context_stack = self.context_stack + if not is_set(context_stack): return - if self.is_context_stack_async(self.context_stack): - await self.context_stack.aclose() - elif self.is_context_stack_sync(self.context_stack): - self.context_stack.close() - self.context_stack = None - self.instance = None + if self.is_context_stack_async(context_stack): + await context_stack.aclose() + elif self.is_context_stack_sync(context_stack): + context_stack.close() + self.set_context_state(instance=UNSET, context_stack=UNSET) @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: """Tear down the sync context stack.""" - if self.context_stack is None: + context_stack = self.context_stack + if not is_set(context_stack): return - if self.is_context_stack_sync(self.context_stack): - self.context_stack.close() - self.context_stack = None - self.instance = None - elif self.is_context_stack_async(self.context_stack): + if self.is_context_stack_sync(context_stack): + context_stack.close() + self.set_context_state(instance=UNSET, context_stack=UNSET) + elif self.is_context_stack_async(context_stack): msg = "Cannot tear down async context in sync mode" if raise_on_async: raise CannotTearDownSyncError(msg) diff --git a/that_depends/experimental/providers.py b/that_depends/experimental/providers.py index c79becf6..6a5b1ec5 100644 --- a/that_depends/experimental/providers.py +++ b/that_depends/experimental/providers.py @@ -7,7 +7,7 @@ from that_depends import ContextScope from that_depends.providers import AbstractProvider -from that_depends.providers.context_resources import CT, SupportsContext +from that_depends.providers.context_resources import CT_co, SupportsContext from that_depends.providers.mixin import SupportsTeardown @@ -55,32 +55,19 @@ def __init__( @typing_extensions.override def get_scope(self) -> ContextScope | None: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.get_scope() - msg = "Underlying provider does not support context scopes" - raise NotImplementedError(msg) + return typing.cast(ContextScope | None, self._call_context_method("get_scope")) @typing_extensions.override - def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.context_async(force) - msg = "Underlying provider does not support context management" - raise NotImplementedError(msg) + def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]: + return typing.cast(typing.AsyncContextManager[CT_co], self._call_context_method("context_async", force)) @typing_extensions.override - def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.context_sync(force) - msg = "Underlying provider does not support context management" - raise NotImplementedError(msg) + def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]: + return typing.cast(typing.ContextManager[CT_co], self._call_context_method("context_sync", force)) @typing_extensions.override def supports_context_sync(self) -> bool: - provider = self._get_provider() - return isinstance(provider, SupportsContext) and provider.supports_context_sync() + return typing.cast(bool, self._call_context_method("supports_context_sync")) @typing_extensions.override async def tear_down(self, propagate: bool = True) -> None: @@ -139,6 +126,19 @@ def _get_provider(self) -> AbstractProvider[Any]: self._provider = cast(AbstractProvider[Any], provider) return self._provider + def _call_context_method(self, method_name: str, *args: object) -> object: + provider = self._get_provider() + try: + method = getattr(type(provider), method_name) + except AttributeError as e: + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) from e + try: + return method(provider, *args) + except AttributeError as e: + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) from e + @typing_extensions.override async def resolve(self) -> Any: provider = self._get_provider() diff --git a/that_depends/injection.py b/that_depends/injection.py index 3a48dba7..bd020b73 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -1,19 +1,18 @@ -import asyncio -import contextlib import functools import inspect import re -import threading import typing import warnings -from contextlib import AsyncExitStack, ExitStack +from contextlib import AsyncExitStack +from types import TracebackType + +from typing_extensions import Self from that_depends.container import BaseContainer from that_depends.exceptions import TypeNotBoundError from that_depends.meta import BaseContainerMeta -from that_depends.providers import AbstractProvider, ContextResource +from that_depends.providers import AbstractProvider from that_depends.providers.context_resources import ContextScope, ContextScopes, container_context -from that_depends.providers.mixin import ProviderWithArguments class ContextProviderError(Exception): @@ -29,6 +28,107 @@ class ContextProviderError(Exception): ) +class _DirectInjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + provider: AbstractProvider[typing.Any] + scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] + + +class _StringInjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + definition: "StringProviderDefinition" + + +class _TypedInjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + annotation: type[typing.Any] + + +class _InjectionPlan(typing.NamedTuple): + direct_parameters: tuple[_DirectInjectionParameter, ...] + string_parameters: tuple[_StringInjectionParameter, ...] + typed_parameters: tuple[_TypedInjectionParameter, ...] + + +class _SyncInjectionStack: + __slots__ = ("_exit_states",) + + def __init__(self) -> None: + self._exit_states: list[_SupportsClose] = [] + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> typing.Literal[False]: + _ = exc_type, exc_value, traceback + while self._exit_states: + self._exit_states.pop().close() + return False + + def enter_context(self, context_manager: typing.ContextManager[T]) -> T: + value = context_manager.__enter__() + self._exit_states.append(_ContextManagerExitState(context_manager)) + return value + + def push_exit_state(self, exit_state: "_SupportsClose") -> None: + self._exit_states.append(exit_state) + + +class _SupportsClose(typing.Protocol): + def close(self) -> None: ... + + +class _ContextManagerExitState: + __slots__ = ("_context_manager",) + + def __init__(self, context_manager: typing.ContextManager[typing.Any]) -> None: + self._context_manager = context_manager + + def close(self) -> None: + self._context_manager.__exit__(None, None, None) + + +@functools.cache +def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionPlan: + direct_parameters: list[_DirectInjectionParameter] = [] + string_parameters: list[_StringInjectionParameter] = [] + typed_parameters: list[_TypedInjectionParameter] = [] + for index, (field_name, param) in enumerate(inspect.signature(func).parameters.items()): + default = param.default + if isinstance(default, StringProviderDefinition): + string_parameters.append(_StringInjectionParameter(index, field_name, default)) + elif isinstance(default, AbstractProvider): + direct_parameters.append( + _DirectInjectionParameter( + index, + field_name, + default, + default._get_scope_context_init_order(), # noqa: SLF001 + ) + ) + elif isinstance(default, _Provide): + typed_parameters.append( + _TypedInjectionParameter( + index, + field_name, + typing.cast(type[typing.Any], param.annotation), + ) + ) + return _InjectionPlan( + direct_parameters=tuple(direct_parameters), + string_parameters=tuple(string_parameters), + typed_parameters=tuple(typed_parameters), + ) + + @typing.overload def inject(func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... @@ -88,10 +188,11 @@ def _inject( def _inject_to_sync_gen( gen: typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]], ) -> typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]]: + plan = _build_injection_plan(gen) + @functools.wraps(gen) def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any, typing.Any]: - signature = inspect.signature(gen) - injected, kwargs = _resolve_arguments_sync(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=2) @@ -105,11 +206,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any, def _inject_to_async_gen( gen: typing.Callable[P, typing.AsyncGenerator[T, typing.Any]], ) -> typing.Callable[P, typing.AsyncGenerator[T, typing.Any]]: + plan = _build_injection_plan(gen) + @functools.wraps(gen) async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, typing.Any]: - signature = inspect.signature(gen) - - injected, kwargs = await _resolve_arguments_async(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = await _resolve_arguments_async(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1) @@ -122,26 +223,30 @@ async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, t def _inject_to_async( func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], ) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]: + plan = _build_injection_plan(func) + @functools.wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> T: if enter_scope: async with container_context(scope=scope): - return await _resolve_async(func, None, container, *args, **kwargs) + return await _resolve_async(func, plan, None, container, *args, **kwargs) else: - return await _resolve_async(func, scope, container, *args, **kwargs) + return await _resolve_async(func, plan, scope, container, *args, **kwargs) return inner def _inject_to_sync( func: typing.Callable[P, T], ) -> typing.Callable[P, T]: + plan = _build_injection_plan(func) + @functools.wraps(func) def inner(*args: P.args, **kwargs: P.kwargs) -> T: if enter_scope: with container_context(scope=scope): - return _resolve_sync(func, None, container, *args, **kwargs) + return _resolve_sync(func, plan, None, container, *args, **kwargs) else: - return _resolve_sync(func, scope, container, *args, **kwargs) + return _resolve_sync(func, plan, scope, container, *args, **kwargs) return inner @@ -150,131 +255,169 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T: return _inject(func) -_SYNC_SIGNATURE_CACHE: dict[typing.Callable[..., typing.Any], inspect.Signature] = {} -_THREADING_LOCK = threading.Lock() - - async def _resolve_arguments_async( - signature: inspect.Signature, + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None, stack: AsyncExitStack | None, *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - injected = False - context_providers: set[AbstractProvider[typing.Any]] = set() - params = list(signature.parameters.items()) + if not _plan_has_injected_parameters(plan): + return False, kwargs - for i, (field_name, param) in enumerate(params): - default = param.default + context_providers: set[AbstractProvider[typing.Any]] = set() + for direct_parameter in plan.direct_parameters: + if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs): + continue - if i < len(args) or field_name in kwargs: - if isinstance(default, (AbstractProvider, StringProviderDefinition)): - injected = True + if direct_parameter.scope_context_init_order: + await _setup_scope_contexts_async( + direct_parameter.scope_context_init_order, + scope, + stack, + context_providers, + ) + kwargs[direct_parameter.field_name] = await direct_parameter.provider.resolve() + + for string_parameter in plan.string_parameters: + if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs): continue - if isinstance(default, StringProviderDefinition): - injected = True - resolved_val = await _resolve_provider_with_scope_async(default.provider, scope, stack, context_providers) - kwargs[field_name] = resolved_val - elif isinstance(default, AbstractProvider): - injected = True - resolved_val = await _resolve_provider_with_scope_async(default, scope, stack, context_providers) - kwargs[field_name] = resolved_val + kwargs[string_parameter.field_name] = await _resolve_provider_with_scope_async( + string_parameter.definition.provider, + scope, + stack, + context_providers, + ) - elif isinstance(default, _Provide): - injected = True - if container is None: - raise RuntimeError(_PROVIDE_MESSAGE) - try: - provider = container.get_provider_for_type(signature.parameters[field_name].annotation) - except TypeNotBoundError as e: - msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider." - raise RuntimeError(msg) from e - kwargs[field_name] = await _resolve_provider_with_scope_async(provider, scope, stack, context_providers) - return injected, kwargs + for typed_parameter in plan.typed_parameters: + if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs): + continue + + provider = _resolve_typed_provider(typed_parameter.annotation, container) + kwargs[typed_parameter.field_name] = await _resolve_provider_with_scope_async( + provider, + scope, + stack, + context_providers, + ) + return True, kwargs def _resolve_arguments_sync( - signature: inspect.Signature, + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None, - stack: contextlib.ExitStack | None, + stack: _SyncInjectionStack | None, *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - injected = False + if not _plan_has_injected_parameters(plan): + return False, kwargs + context_providers: set[AbstractProvider[typing.Any]] = set() - for i, (field_name, param) in enumerate(signature.parameters.items()): - default = param.default - if i < len(args) or field_name in kwargs: - if isinstance(default, (AbstractProvider, StringProviderDefinition, _Provide)): - injected = True + for direct_parameter in plan.direct_parameters: + if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs): continue - if isinstance(default, StringProviderDefinition): - injected = True - kwargs[field_name] = _resolve_provider_with_scope_sync(default.provider, scope, stack, context_providers) - elif isinstance(default, AbstractProvider): - injected = True - kwargs[field_name] = _resolve_provider_with_scope_sync(default, scope, stack, context_providers) - elif isinstance(default, _Provide): - injected = True - if container is None: - raise RuntimeError(_PROVIDE_MESSAGE) - try: - provider = container.get_provider_for_type(signature.parameters[field_name].annotation) - except TypeNotBoundError as e: - msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider." - raise RuntimeError(msg) from e - kwargs[field_name] = _resolve_provider_with_scope_sync(provider, scope, stack, context_providers) + if direct_parameter.scope_context_init_order: + _setup_scope_contexts_sync( + direct_parameter.scope_context_init_order, + scope, + stack, + context_providers, + ) + kwargs[direct_parameter.field_name] = direct_parameter.provider.resolve_sync() + + for string_parameter in plan.string_parameters: + if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs): + continue + + kwargs[string_parameter.field_name] = _resolve_provider_with_scope_sync( + string_parameter.definition.provider, + scope, + stack, + context_providers, + ) + + for typed_parameter in plan.typed_parameters: + if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs): + continue + + provider = _resolve_typed_provider(typed_parameter.annotation, container) + kwargs[typed_parameter.field_name] = _resolve_provider_with_scope_sync( + provider, + scope, + stack, + context_providers, + ) + + return True, kwargs + + +def _plan_has_injected_parameters(plan: _InjectionPlan) -> bool: + return bool(plan.direct_parameters or plan.string_parameters or plan.typed_parameters) + + +def _is_argument_provided( + argument_index: int, + field_name: str, + args: tuple[typing.Any, ...], + kwargs: dict[str, typing.Any], +) -> bool: + return argument_index < len(args) or field_name in kwargs + - return injected, kwargs +def _resolve_typed_provider( + annotation: type[typing.Any], + container: BaseContainerMeta | None, +) -> AbstractProvider[typing.Any]: + if container is None: + raise RuntimeError(_PROVIDE_MESSAGE) + try: + return container.get_provider_for_type(annotation) + except TypeNotBoundError as e: + msg = f"Type {annotation} is not bound to a provider." + raise RuntimeError(msg) from e def _resolve_sync( func: typing.Callable[P, T], + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None = None, *args: P.args, **kwargs: P.kwargs, ) -> T: - if func not in _SYNC_SIGNATURE_CACHE: - with _THREADING_LOCK: - _SYNC_SIGNATURE_CACHE[func] = inspect.signature(func) - signature = _SYNC_SIGNATURE_CACHE[func] - - with ExitStack() as stack: - injected, kwargs = _resolve_arguments_sync(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment] - + if scope is None: + injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3) return func(*args, **kwargs) + with _SyncInjectionStack() as stack: + injected, kwargs = _resolve_arguments_sync(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment] -_SIGNATURE_CACHE: dict[ - typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, typing.Any]], inspect.Signature -] = {} + if not injected: + warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3) -_ASYNCIO_LOCK = asyncio.Lock() + return func(*args, **kwargs) + + raise RuntimeError # pragma: no cover # to prevent mypy issue async def _resolve_async( func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None = None, *args: P.args, **kwargs: P.kwargs, ) -> T: - if func not in _SIGNATURE_CACHE: - async with _ASYNCIO_LOCK: - _SIGNATURE_CACHE[func] = inspect.signature(func) - signature = _SIGNATURE_CACHE[func] - async with AsyncExitStack() as stack: - injected, kwargs = await _resolve_arguments_async(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = await _resolve_arguments_async(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1) @@ -306,30 +449,25 @@ async def _resolve_provider_with_scope_async( ContextProviderError: if the stack is None. """ - await _add_provider_to_stack_async(provider, stack, scope, providers) + scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001 + if scope_context_init_order: + await _setup_scope_contexts_async(scope_context_init_order, scope, stack, providers) return await provider.resolve() -async def _add_provider_to_stack_async( - provider: AbstractProvider[T], - stack: AsyncExitStack | None, +async def _setup_scope_contexts_async( + scope_init_order: tuple[AbstractProvider[typing.Any], ...], scope: ContextScope | None, + stack: AsyncExitStack | None, providers: set[AbstractProvider[typing.Any]], ) -> None: - if provider in providers: - return - providers.add(provider) - if not scope: return - if isinstance(provider, ProviderWithArguments): - provider._register_arguments() # noqa: SLF001 - - parents = provider._parents # noqa: SLF001 - for parent in parents: - await _add_provider_to_stack_async(parent, stack, scope, providers) - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() + for provider in scope_init_order: + if provider in providers: + continue + providers.add(provider) + provider_scope = provider._scope # noqa: SLF001 if provider_scope in (ContextScopes.ANY, scope): if stack is None: msg = ( @@ -343,34 +481,28 @@ async def _add_provider_to_stack_async( def _resolve_provider_with_scope_sync( provider: AbstractProvider[T], scope: ContextScope | None, - stack: ExitStack | None, + stack: _SyncInjectionStack | None, providers: set[AbstractProvider[typing.Any]], ) -> T: - _add_provider_to_stack_sync(provider, stack, scope, providers) + scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001 + if scope_context_init_order: + _setup_scope_contexts_sync(scope_context_init_order, scope, stack, providers) return provider.resolve_sync() -def _add_provider_to_stack_sync( - provider: AbstractProvider[T], - stack: ExitStack | None, +def _setup_scope_contexts_sync( + scope_init_order: tuple[AbstractProvider[typing.Any], ...], scope: ContextScope | None, + stack: _SyncInjectionStack | None, providers: set[AbstractProvider[typing.Any]], ) -> None: - if provider in providers: - return - providers.add(provider) - if not scope: return - if isinstance(provider, ProviderWithArguments): - provider._register_arguments() # noqa: SLF001 - - parents = provider._parents # noqa: SLF001 - for parent in parents: - _add_provider_to_stack_sync(parent, stack, scope, providers) - - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() + for provider in scope_init_order: + if provider in providers: + continue + providers.add(provider) + provider_scope = provider._scope # noqa: SLF001 if provider_scope in (ContextScopes.ANY, scope): if stack is None: msg = ( @@ -378,7 +510,8 @@ def _add_provider_to_stack_sync( f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." ) raise ContextProviderError(msg) - stack.enter_context(provider.context_sync(force=True)) + _, exit_state = provider._enter_injection_context_sync(force=True) # noqa: SLF001 + stack.push_exit_state(exit_state) class StringProviderDefinition: diff --git a/that_depends/integrations/faststream.py b/that_depends/integrations/faststream.py index 67d9fc64..2120f5bf 100644 --- a/that_depends/integrations/faststream.py +++ b/that_depends/integrations/faststream.py @@ -22,7 +22,7 @@ class DIContextMiddleware(BaseMiddleware): def __init__( self, - *context_items: SupportsContext[Any], + *context_items: SupportsContext[typing.Any], msg: AnyMsg | None = None, context: Optional["ContextRepo"] = None, global_context: dict[str, Any] | Unset = UNSET, @@ -31,7 +31,7 @@ def __init__( """Initialize the container context middleware. Args: - *context_items (SupportsContext[Any]): Context items to initialize. + *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize. msg (Any): Message object. context (ContextRepo): Context repository. global_context (dict[str, Any] | Unset): Global context to initialize the container. @@ -40,7 +40,7 @@ def __init__( """ super().__init__(msg, context=context) # type: ignore[arg-type] self._context: container_context | None = None - self._context_items = set(context_items) + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context = global_context self._scope = scope @@ -79,8 +79,8 @@ def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # return DIContextMiddleware( *self._context_items, - msg=msg, - context=context, + msg=msg, # type:ignore[unexpected-keyword] + context=context, # type:ignore[unexpected-keyword] scope=self._scope, global_context=self._global_context, ) @@ -93,21 +93,21 @@ class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] def __init__( self, - *context_items: SupportsContext[Any], + *context_items: SupportsContext[typing.Any], global_context: dict[str, Any] | Unset = UNSET, scope: ContextScope | Unset = UNSET, ) -> None: """Initialize the container context middleware. Args: - *context_items (SupportsContext[Any]): Context items to initialize. + *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize. global_context (dict[str, Any] | Unset): Global context to initialize the container. scope (ContextScope | Unset): Context scope to initialize the container. """ super().__init__() # type: ignore[call-arg] self._context: container_context | None = None - self._context_items = set(context_items) + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context = global_context self._scope = scope diff --git a/that_depends/meta.py b/that_depends/meta.py index 21507a49..614479d0 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -2,7 +2,6 @@ import typing import warnings from collections.abc import MutableMapping -from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from threading import Lock from typing import TYPE_CHECKING @@ -11,7 +10,7 @@ import that_depends from that_depends.exceptions import TypeNotBoundError from that_depends.providers import AbstractProvider, Resource -from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes, SupportsContext +from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes from that_depends.providers.mixin import SupportsTeardown @@ -43,41 +42,15 @@ def __setitem__(self, key: str, value: typing.Any) -> None: super().__setitem__(key, value) -class BaseContainerMeta(abc.ABCMeta, SupportsContext[None]): +class BaseContainerMeta(abc.ABCMeta): """Metaclass for BaseContainer.""" - @override def get_scope(cls) -> ContextScope | None: + """Return the default scope used by the container.""" if scope := getattr(cls, "default_scope", None): return typing.cast(ContextScope | None, scope) return ContextScopes.ANY - @asynccontextmanager - @override - async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]: - async with AsyncExitStack() as stack: - for container in cls.get_containers(): - await stack.enter_async_context(container.context_async(force=force)) - for provider in cls.get_providers().values(): - if isinstance(provider, ContextResource): - await stack.enter_async_context(provider.context_async(force=force)) - yield - - @contextmanager - @override - def context_sync(cls, force: bool = False) -> typing.Iterator[None]: - with ExitStack() as stack: - for container in cls.get_containers(): - stack.enter_context(container.context_sync(force=force)) - for provider in cls.get_providers().values(): - if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001 - stack.enter_context(provider.context_sync(force=force)) - yield - - @override - def supports_context_sync(cls) -> bool: - return True - _instances: typing.ClassVar[dict[str, type["BaseContainer"]]] = {} _MUTABLE_ATTRS = ( @@ -88,6 +61,7 @@ def supports_context_sync(cls) -> bool: "containers", "alias", "default_scope", + "type_provider_cache", ) _lock: Lock = Lock() @@ -114,10 +88,6 @@ def name(cls) -> str: def __prepare__(cls, name: str, bases: tuple[type, ...], /, **kwds: typing.Any) -> MutableMapping[str, object]: return _ContainerMetaDict() - @override - def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: - return super().__new__(cls, name, bases, namespace) - @classmethod def get_instances(cls) -> dict[str, type["BaseContainer"]]: """Get all instances that inherit from BaseContainer.""" @@ -163,12 +133,18 @@ def get_provider_for_type(cls, t: type[T]) -> AbstractProvider[T]: Provider for the given type. """ + if not hasattr(cls, "type_provider_cache"): + cls.type_provider_cache: dict[type[typing.Any], AbstractProvider[typing.Any]] = {} + if provider := cls.type_provider_cache.get(t): + return typing.cast(AbstractProvider[T], provider) for provider in cls.get_providers().values(): if provider._has_contravariant_bindings: # noqa: SLF001 for bind in provider._bindings: # noqa: SLF001 if issubclass(bind, t): + cls.type_provider_cache[t] = provider return provider elif t in provider._bindings: # noqa: SLF001 + cls.type_provider_cache[t] = provider return provider msg = f"Type {t} is not bound to any provider in container {cls.name()}" raise TypeNotBoundError(msg) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 9c30bdfe..7adc02d1 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -11,6 +11,7 @@ from that_depends.entities.resource_context import ResourceContext from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -18,8 +19,71 @@ P = typing.ParamSpec("P") ResourceCreatorType: typing.TypeAlias = typing.Callable[ P, - typing.Iterator[T_co] | typing.AsyncIterator[T_co] | typing.ContextManager[T_co] | typing.AsyncContextManager[T_co], + typing.Iterator[T_co] + | typing.AsyncIterator[T_co] + | contextlib.AbstractContextManager[T_co] + | contextlib.AbstractAsyncContextManager[T_co], ] +_EMPTY_ARGS: typing.Final[tuple[()]] = () +_EMPTY_KWARGS: typing.Final[dict[str, typing.Any]] = {} + + +async def _resolve_arguments( + args: tuple[typing.Any, ...], + args_are_providers: tuple[bool, ...], +) -> tuple[typing.Any, ...] | list[typing.Any]: + if not args: + return _EMPTY_ARGS + if len(args) == 1: + arg = args[0] + return (await arg.resolve() if args_are_providers[0] else arg,) + return [ + await arg.resolve() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False) + ] + + +def _resolve_arguments_sync( + args: tuple[typing.Any, ...], + args_are_providers: tuple[bool, ...], +) -> tuple[typing.Any, ...] | list[typing.Any]: + if not args: + return _EMPTY_ARGS + if len(args) == 1: + arg = args[0] + return (arg.resolve_sync() if args_are_providers[0] else arg,) + return [ + arg.resolve_sync() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False) + ] + + +async def _resolve_keyword_arguments( + kwargs_items: tuple[tuple[str, typing.Any], ...], + kwargs_are_providers: tuple[bool, ...], +) -> dict[str, typing.Any]: + if not kwargs_items: + return _EMPTY_KWARGS + if len(kwargs_items) == 1: + key, value = kwargs_items[0] + return {key: await value.resolve() if kwargs_are_providers[0] else value} + return { + key: await value.resolve() if is_provider else value + for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False) + } + + +def _resolve_keyword_arguments_sync( + kwargs_items: tuple[tuple[str, typing.Any], ...], + kwargs_are_providers: tuple[bool, ...], +) -> dict[str, typing.Any]: + if not kwargs_items: + return _EMPTY_KWARGS + if len(kwargs_items) == 1: + key, value = kwargs_items[0] + return {key: value.resolve_sync() if kwargs_are_providers[0] else value} + return { + key: value.resolve_sync() if is_provider else value + for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False) + } class AbstractProvider(abc.ABC, typing.Generic[T_co]): @@ -30,7 +94,10 @@ def __init__(self) -> None: super().__init__() self._children: set[AbstractProvider[typing.Any]] = set() self._parents: set[AbstractProvider[typing.Any]] = set() - self._override: typing.Any = None + self._is_context_resource = False + self._scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None + self._scope_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None + self._override: typing.Any = UNSET self._bindings: set[type] = set() self._has_contravariant_bindings: bool = False self._lock = threading.Lock() @@ -62,10 +129,14 @@ def _register(self, candidates: typing.Iterable[typing.Any]) -> None: None """ + changed = False for candidate in candidates: if isinstance(candidate, AbstractProvider): candidate.add_child_provider(self) self._parents.add(candidate) + changed = True + if changed: + self._invalidate_scope_init_order() def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None: """Deregister current provider as child. @@ -77,10 +148,71 @@ def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None: None """ + changed = False for candidate in candidates: if isinstance(candidate, AbstractProvider) and self in candidate._children: # noqa: SLF001 candidate.remove_child_provider(self) self._parents.discard(candidate) + changed = True + if changed: + self._invalidate_scope_init_order() + + def _invalidate_scope_init_order(self) -> None: + stack = [self] + visited: set[AbstractProvider[typing.Any]] = set() + + while stack: + provider = stack.pop() + if provider in visited: + continue + visited.add(provider) + provider._scope_context_init_order = None # noqa: SLF001 + provider._scope_init_order = None # noqa: SLF001 + stack.extend(provider._children) # noqa: SLF001 + + def _get_scope_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]: + if self._scope_init_order is not None: + return self._scope_init_order + + if isinstance(self, ProviderWithArguments): + self._register_arguments() + + ordered: list[AbstractProvider[typing.Any]] = [] + seen: set[AbstractProvider[typing.Any]] = set() + + for parent in self._parents: + for ancestor in parent._get_scope_init_order(): # noqa: SLF001 + if ancestor not in seen: + seen.add(ancestor) + ordered.append(ancestor) + + if self not in seen: + ordered.append(self) + + self._scope_init_order = tuple(ordered) + return self._scope_init_order + + def _get_scope_context_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]: + if self._scope_context_init_order is not None: + return self._scope_context_init_order + + if isinstance(self, ProviderWithArguments): + self._register_arguments() + + ordered: list[AbstractProvider[typing.Any]] = [] + seen: set[AbstractProvider[typing.Any]] = set() + + for parent in self._parents: + for ancestor in parent._get_scope_context_init_order(): # noqa: SLF001 + if ancestor not in seen: + seen.add(ancestor) + ordered.append(ancestor) + + if self._is_context_resource and self not in seen: + ordered.append(self) + + self._scope_context_init_order = tuple(ordered) + return self._scope_context_init_order def add_child_provider(self, provider: "AbstractProvider[typing.Any]") -> None: """Add a child provider to the current provider. @@ -232,7 +364,7 @@ def reset_override_sync( None """ - self._override = None + self._override = UNSET if tear_down_children: eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)] for child in eligible_children: @@ -249,7 +381,7 @@ async def reset_override(self, tear_down_children: bool = False, propagate: bool None """ - self._override = None + self._override = UNSET if tear_down_children: eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)] for child in eligible_children: @@ -306,7 +438,10 @@ def __init__( """ super().__init__() - self._creator: typing.Any + self._creator: typing.Callable[ + ..., + contextlib.AbstractContextManager[T_co] | contextlib.AbstractAsyncContextManager[T_co], + ] if inspect.isasyncgenfunction(creator): self._is_async = True @@ -314,10 +449,10 @@ def __init__( elif inspect.isgeneratorfunction(creator): self._is_async = False self._creator = contextlib.contextmanager(creator) - elif isinstance(creator, type) and issubclass(creator, typing.AsyncContextManager): + elif isinstance(creator, type) and hasattr(creator, "__aenter__") and hasattr(creator, "__aexit__"): self._is_async = True self._creator = creator - elif isinstance(creator, type) and issubclass(creator, typing.ContextManager): + elif isinstance(creator, type) and hasattr(creator, "__enter__") and hasattr(creator, "__exit__"): self._is_async = False self._creator = creator else: @@ -325,60 +460,65 @@ def __init__( raise TypeError(msg) self._args = args self._kwargs = kwargs + self._args_are_providers = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items = tuple(kwargs.items()) + self._kwargs_are_providers = tuple(isinstance(value, AbstractProvider) for _, value in self._kwargs_items) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @abc.abstractmethod def _fetch_context(self) -> ResourceContext[T_co]: ... @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) context = self._fetch_context() # lock to prevent race condition while resolving async with context.asyncio_lock: - if context.instance is not None: + if is_set(context.instance): return context.instance self._register_arguments() - - cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], - **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, + cm: contextlib.AbstractContextManager[T_co] | contextlib.AbstractAsyncContextManager[T_co] = self._creator( + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) - if isinstance(cm, typing.AsyncContextManager): + if isinstance(cm, contextlib.AbstractAsyncContextManager): context.context_stack = contextlib.AsyncExitStack() context.instance = await context.context_stack.enter_async_context(cm) - elif isinstance(cm, typing.ContextManager): + elif isinstance(cm, contextlib.AbstractContextManager): context.context_stack = contextlib.ExitStack() context.instance = context.context_stack.enter_context(cm) else: # pragma: no cover - typing.assert_never(cm) + typing_extensions.assert_never(cm) return context.instance @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) context = self._fetch_context() # lock to prevent race condition while resolving with context.threading_lock: - if context.instance is not None: + if is_set(context.instance): return context.instance if self._is_async: @@ -386,15 +526,15 @@ def resolve_sync(self) -> T_co: raise RuntimeError(msg) self._register_arguments() - cm = self._creator( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) context.context_stack = contextlib.ExitStack() - context.instance = context.context_stack.enter_context(cm) + instance: T_co = context.context_stack.enter_context(cm) # type: ignore[arg-type] + context.instance = instance - return context.instance + return instance def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401 @@ -409,6 +549,8 @@ class AttrGetter( """Provides an attribute after resolving the wrapped provider.""" def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return if isinstance(self._provider, ProviderWithArguments): self._provider._register_arguments() # noqa: SLF001 self._parents = self._provider._parents # noqa: SLF001 @@ -417,6 +559,7 @@ def _deregister_arguments(self) -> None: if isinstance(self._provider, ProviderWithArguments): self._provider._deregister_arguments() # noqa: SLF001 self._parents = self._provider._parents # noqa: SLF001 + self._reset_arguments_registration() __slots__ = "_attrs", "_provider" @@ -433,11 +576,11 @@ def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None: self._attrs = [attr_name] @override - def __getattr__(self, attr: str) -> "AttrGetter[T_co]": - if attr.startswith("_"): - msg = f"'{type(self)}' object has no attribute '{attr}'" + def __getattr__(self, attr_name: str) -> "AttrGetter[T_co]": + if attr_name.startswith("_"): + msg = f"'{type(self)}' object has no attribute '{attr_name}'" raise AttributeError(msg) - self._attrs.append(attr) + self._attrs.append(attr_name) return self @override diff --git a/that_depends/providers/collection.py b/that_depends/providers/collection.py index 28533d11..08e0a4ca 100644 --- a/that_depends/providers/collection.py +++ b/that_depends/providers/collection.py @@ -1,4 +1,6 @@ import typing +from collections.abc import Mapping, Sequence +from types import MappingProxyType from typing_extensions import override @@ -8,10 +10,11 @@ T_co = typing.TypeVar("T_co", covariant=True) -class List(AbstractProvider[list[T_co]]): - """Provides multiple resources as a list. +class List(AbstractProvider[Sequence[T_co]]): + """Provides multiple resources as a read-only sequence. - The `List` provider resolves multiple dependencies into a list. + The `List` provider resolves multiple dependencies into an immutable tuple + exposed through the `Sequence` interface. Example: ```python @@ -24,12 +27,12 @@ class List(AbstractProvider[list[T_co]]): # Synchronous resolution resolved_list = list_provider.resolve_sync() - print(resolved_list) # Output: [1, 2] + print(tuple(resolved_list)) # Output: (1, 2) # Asynchronous resolution import asyncio resolved_list_async = asyncio.run(list_provider.resolve()) - print(resolved_list_async) # Output: [1, 2] + print(tuple(resolved_list_async)) # Output: (1, 2) ``` """ @@ -53,22 +56,24 @@ def __getattr__(self, attr_name: str) -> typing.Any: raise AttributeError(msg) @override - async def resolve(self) -> list[T_co]: - return [await x.resolve() for x in self._providers] + async def resolve(self) -> Sequence[T_co]: + resolved = [await x.resolve() for x in self._providers] + return tuple(resolved) @override - def resolve_sync(self) -> list[T_co]: - return [x.resolve_sync() for x in self._providers] + def resolve_sync(self) -> Sequence[T_co]: + resolved = [x.resolve_sync() for x in self._providers] + return tuple(resolved) @override - async def __call__(self) -> list[T_co]: + async def __call__(self) -> Sequence[T_co]: return await self.resolve() -class Dict(AbstractProvider[dict[str, T_co]]): - """Provides multiple resources as a dictionary. +class Dict(AbstractProvider[Mapping[str, T_co]]): + """Provides multiple resources as a read-only mapping. - The `Dict` provider resolves multiple named dependencies into a dictionary. + The `Dict` provider resolves multiple named dependencies into a read-only mapping. Example: ```python @@ -81,12 +86,12 @@ class Dict(AbstractProvider[dict[str, T_co]]): # Synchronous resolution resolved_dict = dict_provider.resolve_sync() - print(resolved_dict) # Output: {"key1": 1, "key2": 2} + print(dict(resolved_dict)) # Output: {"key1": 1, "key2": 2} # Asynchronous resolution import asyncio resolved_dict_async = asyncio.run(dict_provider.resolve()) - print(resolved_dict_async) # Output: {"key1": 1, "key2": 2} + print(dict(resolved_dict_async)) # Output: {"key1": 1, "key2": 2} ``` """ @@ -110,9 +115,9 @@ def __getattr__(self, attr_name: str) -> typing.Any: raise AttributeError(msg) @override - async def resolve(self) -> dict[str, T_co]: - return {key: await provider.resolve() for key, provider in self._providers.items()} + async def resolve(self) -> Mapping[str, T_co]: + return MappingProxyType({key: await provider.resolve() for key, provider in self._providers.items()}) @override - def resolve_sync(self) -> dict[str, T_co]: - return {key: provider.resolve_sync() for key, provider in self._providers.items()} + def resolve_sync(self) -> Mapping[str, T_co]: + return MappingProxyType({key: provider.resolve_sync() for key, provider in self._providers.items()}) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 4dab61ab..549b3cd2 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -3,7 +3,6 @@ import inspect import logging import typing -from abc import abstractmethod from collections.abc import Iterable from contextlib import AbstractAsyncContextManager, AbstractContextManager from contextvars import ContextVar, Token @@ -11,10 +10,11 @@ from types import TracebackType from typing import Final, overload -from typing_extensions import TypeIs, override +from typing_extensions import Protocol, TypeIs, override, runtime_checkable from that_depends.entities.resource_context import ResourceContext from that_depends.providers.base import AbstractResource +from that_depends.utils import UNSET if typing.TYPE_CHECKING: @@ -41,6 +41,55 @@ class InvalidContextError(RuntimeError): """Raised when an invalid context is being used.""" +class _SyncInjectionExitState(typing.Generic[T_co]): + __slots__ = ("_context", "_context_item", "_token") + + def __init__( + self, + context: ContextVar[ResourceContext[T_co]], + context_item: ResourceContext[T_co], + token: Token[ResourceContext[T_co]], + ) -> None: + self._context = context + self._context_item = context_item + self._token = token + + def close(self) -> None: + context_stack = self._context_item.context_stack + if self._context_item.is_context_stack_sync(context_stack): + context_stack.close() # type: ignore[union-attr] + self._context_item.context_stack = UNSET + self._context_item.instance = UNSET + self._context.reset(self._token) + + +class _SyncContextResourceContext(contextlib.ContextDecorator, AbstractContextManager[ResourceContext[T_co]]): + __slots__ = ("_exit_state", "_force", "_provider") + + def __init__(self, provider: "ContextResource[T_co]", force: bool) -> None: + self._provider = provider + self._force = force + self._exit_state: _SyncInjectionExitState[T_co] | None = None + + @override + def __enter__(self) -> ResourceContext[T_co]: + value, self._exit_state = self._provider._enter_injection_context_sync(force=self._force) # noqa: SLF001 + return value + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if self._exit_state is None: + msg = "Context is not set, call ``__enter__`` first" + raise RuntimeError(msg) + _ = exc_type, exc_value, traceback + self._exit_state.close() + + class ContextScope: """A named context scope.""" @@ -102,22 +151,21 @@ def _enter_named_scope(scope: ContextScope) -> typing.Iterator[ContextScope]: T = typing.TypeVar("T") -CT = typing.TypeVar("CT") +CT_co = typing.TypeVar("CT_co", covariant=True) -class SupportsContext(typing.Generic[CT]): +@runtime_checkable +class SupportsContext(Protocol[CT_co]): """Interface for resources that support context initialization. This interface defines methods to create synchronous and asynchronous context managers, as well as a function decorator for context initialization. """ - @abstractmethod def get_scope(self) -> ContextScope | None: """Return the scope of the resource.""" - @abstractmethod - def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: + def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]: """Create an async context manager for this resource. Args: @@ -133,9 +181,9 @@ def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: ``` """ + ... - @abstractmethod - def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: + def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]: """Create a sync context manager for this resource. Args: @@ -151,8 +199,8 @@ def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: ``` """ + ... - @abstractmethod def supports_context_sync(self) -> bool: """Check whether the resource supports sync context. @@ -160,6 +208,10 @@ def supports_context_sync(self) -> bool: bool: True if sync context is supported, False otherwise. """ + ... + + +BaseContainerType: typing.TypeAlias = type["BaseContainer"] def _get_container_context() -> dict[str, typing.Any] | None: @@ -228,16 +280,20 @@ class ContextResource( @override async def resolve(self) -> T_co: + if not self._strict_scope or self._scope == ContextScopes.ANY: + return await super().resolve() current_scope = get_current_scope() - if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + if self._scope == current_scope: return await super().resolve() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @override def resolve_sync(self) -> T_co: + if not self._strict_scope or self._scope == ContextScopes.ANY: + return super().resolve_sync() current_scope = get_current_scope() - if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + if self._scope == current_scope: return super().resolve_sync() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @@ -274,7 +330,8 @@ def __init__( """ super().__init__(creator, *args, **kwargs) - self._from_creator = creator + self._from_creator: typing.Callable[..., typing.Iterator[T_co] | typing.AsyncIterator[T_co]] = creator + self._is_context_resource = True self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") self._token: Token[ResourceContext[T_co]] | None = None self._async_lock: Final = asyncio.Lock() @@ -334,7 +391,7 @@ def with_config(self, scope: ContextScope | None, strict_scope: bool = False) -> if strict_scope and scope == ContextScopes.ANY: msg = f"Cannot set strict_scope with scope {scope}." raise ValueError(msg) - r = ContextResource(self._from_creator, *self._args, **self._kwargs) # type: ignore[arg-type] + r = ContextResource(self._from_creator, *self._args, **self._kwargs) r._scope = scope r._strict_scope = strict_scope @@ -350,24 +407,48 @@ def _enter_context_sync(self, force: bool = False) -> ResourceContext[T_co]: raise RuntimeError(msg) return self._enter(force) + def _enter_injection_context_sync( + self, + force: bool = False, + ) -> tuple[ResourceContext[T_co], _SyncInjectionExitState[T_co]]: + if self._is_async: + msg = "Please use async context instead." + raise RuntimeError(msg) + if not force and self._scope != ContextScopes.ANY: + current_scope = get_current_scope() + if self._scope != current_scope: + msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" + raise InvalidContextError(msg) + + context_item: ResourceContext[T_co] = ResourceContext(is_async=False) + token = self._context.set(context_item) + return context_item, _SyncInjectionExitState(self._context, context_item, token) + async def _enter_context_async(self, force: bool = False) -> ResourceContext[T_co]: return self._enter(force) def _enter(self, force: bool = False) -> ResourceContext[T_co]: - if not force and self._scope not in (ContextScopes.ANY, get_current_scope()): - msg = f"Cannot enter context for resource with scope {self._scope} in scope {get_current_scope()!r}" - raise InvalidContextError(msg) - self._token = self._context.set(ResourceContext(is_async=self._is_async)) - return self._context.get() + if not force and self._scope != ContextScopes.ANY: + current_scope = get_current_scope() + if self._scope != current_scope: + msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" + raise InvalidContextError(msg) + context_item: ResourceContext[T_co] = ResourceContext(is_async=self._is_async) + self._token = self._context.set(context_item) + return context_item def _exit_context_sync(self) -> None: - if not self._token: + if self._token is None: msg = "Context is not set, call ``_enter_sync_context`` first" raise RuntimeError(msg) try: context_item = self._context.get() - context_item.tear_down_sync() + context_stack = context_item.context_stack + if context_item.is_context_stack_sync(context_stack): + context_stack.close() # type: ignore[union-attr] + context_item.context_stack = UNSET + context_item.instance = UNSET finally: self._context.reset(self._token) @@ -385,21 +466,9 @@ async def _exit_context_async(self) -> None: finally: self._context.reset(self._token) - @contextlib.contextmanager @override - def context_sync(self, force: bool = False) -> typing.Iterator[ResourceContext[T_co]]: - if self._is_async: - msg = "Please use async context instead." - raise RuntimeError(msg) - token = self._token - with self._lock: - val = self._enter_context_sync(force=force) - temp_token = self._token - yield val - with self._lock: - self._token = temp_token - self._exit_context_sync() - self._token = token + def context_sync(self, force: bool = False) -> _SyncContextResourceContext[T_co]: + return _SyncContextResourceContext(self, force) @contextlib.asynccontextmanager @override @@ -423,9 +492,6 @@ def _fetch_context(self) -> ResourceContext[T_co]: raise RuntimeError(msg) from e -ContainerType = typing.TypeVar("ContainerType", bound="type[BaseContainer]") - - class container_context(AbstractContextManager[ContextType], AbstractAsyncContextManager[ContextType]): # noqa: N801 """Initialize contexts for the provided containers or resources. @@ -434,15 +500,16 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex """ __slots__ = ( - "_containers", + "_container_items", + "_container_providers_by_scope", "_context_items", - "_context_providers", "_context_stack", "_context_token", + "_direct_context_items", + "_entered_context_items", "_global_context", "_initial_context", "_preserve_global_context", - "_providers", "_reset_resource_context", "_scope", "_scope_token", @@ -458,7 +525,7 @@ def __init__( """Initialize a new container context. Args: - *context_items (SupportsContext[Any]): Context items to initialize a new context for. + *context_items: Context-capable providers or container classes to initialize. global_context (dict[str, Any] | None): A dictionary representing the global context. preserve_global_context (bool): If True, merges the existing global context with the new one. scope (ContextScope | None): The named scope that should be initialized. @@ -481,13 +548,44 @@ def __init__( self._global_context = global_context self._context_token: Token[ContextType] | None = None self._context_items: typing.Final[set[SupportsContext[typing.Any]]] = set(context_items) - self._context_providers: set[ContextResource[typing.Any]] = set() + self._container_items: tuple[type[BaseContainer], ...] + self._direct_context_items: tuple[SupportsContext[typing.Any], ...] + self._container_providers_by_scope: dict[ContextScope | None, tuple[ContextResource[typing.Any], ...]] = {} + self._entered_context_items: tuple[SupportsContext[typing.Any], ...] = () self._reset_resource_context: typing.Final[bool] = bool(scope) self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None self._scope_token: Token[ContextScope | None] | None = None + self._container_items, self._direct_context_items = self._parse_context_items(self._context_items) + + def _parse_context_items( + self, + context_items: set[SupportsContext[typing.Any]], + ) -> tuple[tuple[type["BaseContainer"], ...], tuple[SupportsContext[typing.Any], ...]]: + from that_depends.container import BaseContainer # noqa: PLC0415 + + containers: list[type[BaseContainer]] = [] + direct_items: list[SupportsContext[typing.Any]] = [] + for item in context_items: + if isinstance(item, type) and issubclass(item, BaseContainer): + containers.append(item) + else: + direct_items.append(item) + return tuple(containers), tuple(direct_items) - def _resolve_initial_conditions(self) -> None: - self._scope = self._scope or get_current_scope() + def _get_context_providers_for_scope( + self, + scope: ContextScope | None, + ) -> tuple[ContextResource[typing.Any], ...]: + cached = self._container_providers_by_scope.get(scope) + if cached is None: + providers: set[ContextResource[typing.Any]] = set() + self._add_providers_from_containers(self._container_items, providers, scope) + cached = tuple(providers) + self._container_providers_by_scope[scope] = cached + return cached + + def _resolve_initial_conditions(self) -> ContextScope | None: + scope = self._scope or get_current_scope() if self._preserve_global_context and self._global_context: if context := _get_container_context(): self._initial_context = {**context, **self._global_context} @@ -499,44 +597,48 @@ def _resolve_initial_conditions(self) -> None: ) else: self._initial_context = self._global_context or {} + entered_context_items = dict.fromkeys(self._direct_context_items) + for provider in self._get_context_providers_for_scope(scope): + entered_context_items[provider] = provider if self._reset_resource_context: from that_depends.meta import BaseContainerMeta # noqa: PLC0415 - self._add_providers_from_containers(BaseContainerMeta.get_instances().values(), self._scope) - for item in self._context_items: - from that_depends.container import BaseContainer # noqa: PLC0415 - - if isinstance(item, type) and issubclass(item, BaseContainer): - self._add_providers_from_containers([item], self._scope) - elif isinstance(item, ContextResource): - self._context_providers.add(item) + scope_providers: set[ContextResource[typing.Any]] = set() + self._add_providers_from_containers(BaseContainerMeta.get_instances().values(), scope_providers, scope) + for provider in scope_providers: + entered_context_items[provider] = provider + self._entered_context_items = tuple(entered_context_items) + return scope def _add_providers_from_containers( - self, containers: Iterable[ContainerType], scope: ContextScope | None = ContextScopes.ANY + self, + containers: Iterable[BaseContainerType], + target: set[ContextResource[typing.Any]], + scope: ContextScope | None = ContextScopes.ANY, ) -> None: for container in containers: for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): provider_scope = container_provider.get_scope() if provider_scope in (scope, ContextScopes.ANY): - self._context_providers.add(container_provider) + target.add(container_provider) @override def __enter__(self) -> ContextType: - self._resolve_initial_conditions() + scope = self._resolve_initial_conditions() self._context_stack = contextlib.ExitStack() - self._scope_token = _set_current_scope(self._scope) - for item in self._context_providers: + self._scope_token = _set_current_scope(scope) + for item in self._entered_context_items: if item.supports_context_sync(): self._context_stack.enter_context(item.context_sync()) return self._enter_globals() @override async def __aenter__(self) -> ContextType: - self._resolve_initial_conditions() + scope = self._resolve_initial_conditions() self._context_stack = contextlib.AsyncExitStack() - self._scope_token = _set_current_scope(self._scope) - for item in self._context_providers: + self._scope_token = _set_current_scope(scope) + for item in self._entered_context_items: await self._context_stack.enter_async_context(item.context_async()) return self._enter_globals() @@ -667,8 +769,7 @@ def __init__( Args: app (ASGIApp): The ASGI application to wrap. - *context_items (SupportsContext[Any]): A collection of containers and providers that - need context initialization prior to a request. + *context_items: Containers and providers that need context initialization prior to a request. global_context (dict[str, Any] | None): A global context dictionary to set before requests. scope (ContextScope | None): The scope in which the context should be initialized. diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index 6242ea5f..9102d0da 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -5,8 +5,15 @@ from typing_extensions import override -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import ( + AbstractProvider, + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -76,13 +83,23 @@ def build_resource(text: str, number: int): """ def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: raise NotImplementedError - __slots__ = "_args", "_factory", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_factory", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: """Initialize a Factory instance. @@ -94,37 +111,34 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co]] = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) self._register_arguments() @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) @@ -147,13 +161,23 @@ async def async_build_resource(text: str): """ def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: raise NotImplementedError - __slots__ = "_args", "_factory", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_factory", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) @overload def __init__( @@ -176,22 +200,27 @@ def __init__( """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co | typing.Awaitable[T_co]]] = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) self._register_arguments() @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) - args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args] - kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()} + args = await _resolve_arguments(self._args, self._args_are_providers) + kwargs = await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers) result = self._factory( - *args, # type:ignore[arg-type] - **kwargs, # type:ignore[arg-type] + *args, + **kwargs, ) if inspect.isawaitable(result): diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index dd8f1ef0..20228f11 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -5,7 +5,14 @@ from typing_extensions import override from that_depends.providers import AbstractProvider +from that_depends.providers.base import ( + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -50,67 +57,77 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co]] = factory self._thread_local = threading.local() self._asyncio_lock = asyncio.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @property - def _instance(self) -> T_co | None: - return getattr(self._thread_local, "instance", None) + def _instance(self) -> T_co | Unset: + return typing.cast(T_co | Unset, getattr(self._thread_local, "instance", UNSET)) @_instance.setter - def _instance(self, value: T_co | None) -> None: + def _instance(self, value: T_co | Unset) -> None: self._thread_local.instance = value @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) + if is_set(self._instance): + return self._instance async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() - self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + instance = self._factory( + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) - return self._instance + self._instance = instance + return instance @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() - self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + instance = self._factory( + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) - return self._instance + self._instance = instance + return instance @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) @@ -122,8 +139,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, subsequent calls to `resolve_sync()` on the same thread will produce a new instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() diff --git a/that_depends/providers/mixin.py b/that_depends/providers/mixin.py index 5462230b..99597af0 100644 --- a/that_depends/providers/mixin.py +++ b/that_depends/providers/mixin.py @@ -26,6 +26,22 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> class ProviderWithArguments(abc.ABC): """Interface for providers that require arguments.""" + __slots__ = ("_arguments_registered",) + + def __init__(self) -> None: + """Initialize provider argument registration state.""" + super().__init__() + self._arguments_registered = False + + def _mark_arguments_registered(self) -> bool: + if self._arguments_registered: + return False + self._arguments_registered = True + return True + + def _reset_arguments_registration(self) -> None: + self._arguments_registered = False + @abc.abstractmethod def _register_arguments(self) -> None: """Register arguments for the provider.""" diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py index b29ed7cd..c0843eae 100644 --- a/that_depends/providers/object.py +++ b/that_depends/providers/object.py @@ -3,6 +3,7 @@ from typing_extensions import override from that_depends.providers.base import AbstractProvider +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -42,6 +43,6 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) return self._obj diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py index 5bbe023f..9ec2c48b 100644 --- a/that_depends/providers/selector.py +++ b/that_depends/providers/selector.py @@ -5,6 +5,7 @@ from typing_extensions import override from that_depends.providers.base import AbstractProvider +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -69,7 +70,7 @@ def my_selector(): @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) if isinstance(self._selector, AbstractProvider): @@ -83,7 +84,7 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) if isinstance(self._selector, AbstractProvider): diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index daa470d6..961f77ab 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -6,8 +6,15 @@ from typing_extensions import override -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import ( + AbstractProvider, + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -34,7 +41,18 @@ def my_factory() -> float: """ - __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override", "_threading_lock" + __slots__ = ( + "_args", + "_args_are_providers", + "_asyncio_lock", + "_factory", + "_instance", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + "_threading_lock", + ) def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: """Initialize the Singleton provider. @@ -46,54 +64,64 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory - self._instance: T_co | None = None + self._factory: typing.Final[typing.Callable[..., T_co]] = factory + self._instance: T_co | Unset = UNSET self._asyncio_lock: typing.Final = asyncio.Lock() self._threading_lock: typing.Final = threading.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): self._register_arguments() return typing.cast(T_co, self._override) + if is_set(self._instance): + return self._instance # lock to prevent resolving several times async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): self._register_arguments() return typing.cast(T_co, self._override) + if is_set(self._instance): + return self._instance # lock to prevent resolving several times with self._threading_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @@ -103,8 +131,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, the next resolve() call will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() @@ -115,8 +143,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> After calling this method, the next resolve call will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) @@ -142,7 +170,17 @@ async def my_async_factory() -> float: """ - __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_asyncio_lock", + "_factory", + "_instance", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) def __init__( self, @@ -159,37 +197,45 @@ def __init__( """ super().__init__() - self._factory: typing.Final[typing.Callable[P, typing.Awaitable[T_co]]] = factory - self._instance: T_co | None = None + self._factory: typing.Final[typing.Callable[..., typing.Awaitable[T_co]]] = factory + self._instance: T_co | Unset = UNSET self._asyncio_lock: typing.Final = asyncio.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) + if is_set(self._instance): + return self._instance # lock to prevent resolving several times async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() self._instance = await self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @@ -204,8 +250,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, the next call to ``resolve()`` will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() @@ -216,8 +262,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> After calling this method, the next call to ``resolve_sync()`` will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)