diff --git a/Justfile b/Justfile
index 7bab81d5..1c25f065 100644
--- a/Justfile
+++ b/Justfile
@@ -8,12 +8,14 @@ install:
lint:
uv run ruff format
uv run ruff check --fix
- uv run mypy .
+ uv run mypy . --disable-error-code=unused-ignore
+ uv run pyrefly check --no-progress-bar
lint-ci:
uv run ruff format --check
uv run ruff check --no-fix
- uv run mypy .
+ uv run mypy . --disable-error-code=unused-ignore
+ uv run pyrefly check --no-progress-bar
test *args:
uv run --no-sync pytest {{ args }}
diff --git a/docs/migration/v4.md b/docs/migration/v4.md
new file mode 100644
index 00000000..65cb9321
--- /dev/null
+++ b/docs/migration/v4.md
@@ -0,0 +1,92 @@
+# Migrating from 3.\* to 4.\*
+
+## How to Read This Guide
+
+This guide is intended to help you migrate existing functionality from `that-depends` version `3.*` to `4.*`.
+The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase.
+
+This migration intentionally focuses only on the **simplest change needed to preserve existing behaviour**.
+
+If you want to learn more about the new internals introduced in `4.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases).
+
+---
+
+## Changes in the API
+
+### **Collection providers now return read-only container types**
+
+In `4.*`, collection providers no longer resolve to mutable built-in containers:
+
+- `providers.List(...)` now resolves to a `Sequence` implemented as a `tuple`
+- `providers.Dict(...)` now resolves to a `Mapping` implemented as a `mappingproxy`
+
+If your existing code only reads from these values, you likely do not need to change anything.
+
+If your existing code mutates the resolved value, the **simplest way to preserve the old behaviour** is to convert the result at the call site:
+
+```python
+items = list(MyContainer.items.resolve_sync())
+mapping = dict(MyContainer.mapping.resolve_sync())
+```
+
+---
+
+## Behaviour-Preserving Migration Examples
+
+### **`providers.List(...)`**
+
+Previously in `3.*`, code like this returned a `list`:
+
+```python
+items = MyContainer.items.resolve_sync()
+items.append("new-item")
+```
+
+In `4.*`, `items` is a read-only sequence, so mutating it directly will no longer work.
+
+To preserve the previous behaviour, change it to:
+
+```python
+items = list(MyContainer.items.resolve_sync())
+items.append("new-item")
+```
+
+The same applies to async resolution:
+
+```python
+items = list(await MyContainer.items.resolve())
+items.append("new-item")
+```
+
+---
+
+### **`providers.Dict(...)`**
+
+Previously in `3.*`, code like this returned a mutable `dict`:
+
+```python
+mapping = MyContainer.mapping.resolve_sync()
+mapping["extra"] = "value"
+```
+
+In `4.*`, `mapping` is a read-only mapping, so item assignment will no longer work.
+
+To preserve the previous behaviour, change it to:
+
+```python
+mapping = dict(MyContainer.mapping.resolve_sync())
+mapping["extra"] = "value"
+```
+
+The same applies to async resolution:
+
+```python
+mapping = dict(await MyContainer.mapping.resolve())
+mapping["extra"] = "value"
+```
+
+---
+
+## Further Help
+
+If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues).
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
index 7496f3a4..1f99571a 100644
--- a/docs/overrides/main.html
+++ b/docs/overrides/main.html
@@ -1,5 +1,5 @@
{% extends "base.html" %}
{% block announce %}
- that-depends 3.0 just released! If you are upgrading, check out the migration guide.
+ that-depends 4.0 just released! If you are upgrading, check out the migration guide.
{% endblock %}
diff --git a/docs/providers/collections.md b/docs/providers/collections.md
index 34294ba9..86154416 100644
--- a/docs/providers/collections.md
+++ b/docs/providers/collections.md
@@ -3,7 +3,7 @@ There are several collection providers: `List` and `Dict`
## List
- List provider contains other providers.
-- Resolves into list of dependencies.
+- Resolves into an immutable sequence of dependencies.
```python
import random
@@ -16,12 +16,12 @@ class DIContainer(BaseContainer):
DIContainer.numbers_sequence.resolve_sync()
-# [0.3035656170071561, 0.8280498192037787]
+# (0.3035656170071561, 0.8280498192037787)
```
## Dict
- Dict provider is a collection of named providers.
-- Resolves into dict of dependencies.
+- Resolves into a read-only mapping of dependencies.
```python
import random
@@ -34,5 +34,5 @@ class DIContainer(BaseContainer):
DIContainer.numbers_map.resolve_sync()
-# {'key1': 0.6851384528299208, 'key2': 0.41044920948045294}
+# mappingproxy({'key1': 0.6851384528299208, 'key2': 0.41044920948045294})
```
diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md
index 8aaa327e..87e00c97 100644
--- a/docs/providers/context-resources.md
+++ b/docs/providers/context-resources.md
@@ -8,7 +8,7 @@
To interact with both types of contexts, there are two separate interfaces:
1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers.
-2. Directly manage a `ContextResource` context by using the `SupportsContext` interface, which both containers
+2. Directly manage a `ContextResource` context by using the `SupportsContext` protocol, which both containers
and `ContextResource` providers implement.
---
@@ -185,7 +185,7 @@ async with container_context(MyContainer.async_resource):
...
```
-It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` interface described
+It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` protocol described
[here](#quick-reference).
### Context Hierarchy
diff --git a/examples/benchmark/RESULTS.md b/examples/benchmark/RESULTS.md
index 90ff8dd0..ef211fbc 100644
--- a/examples/benchmark/RESULTS.md
+++ b/examples/benchmark/RESULTS.md
@@ -3,6 +3,8 @@ Based on this [benchmark](injection.py):
| version / iterations | 10^4 | 10^5 | 10^6 |
|----------------------|--------|--------|---------|
+| 4.0 | 0.0399 | 0.3950 | 3.8107 |
+| 3.9.2 | 0.0920 | 0.9804 | 9.1009 |
| 3.2.0 | 0.1039 | 1.0563 | 10.3576 |
| 3.0.0.a2 | 0.0870 | 0.9430 | 8.8815 |
| 3.0.0.a1 | 0.1399 | 1.4136 | 14.1829 |
diff --git a/mkdocs.yml b/mkdocs.yml
index 67e1bb74..416513c8 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -37,6 +37,7 @@ nav:
- Migration:
- 1.* to 2.*: migration/v2.md
- 2.* to 3.*: migration/v3.md
+ - 3.* to 4.*: migration/v4.md
- Development:
- Contributing: dev/contributing.md
- Decisions: dev/main-decisions.md
diff --git a/pyproject.toml b/pyproject.toml
index 2b9359eb..077ab986 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ dev = [
"mkdocs>=1.6.1",
"pytest-randomly",
"mkdocs-llmstxt>=0.4.0",
- "pydantic<2.12.4" # causes tests to fail because of: TypeError: _eval_type() got an unexpected keyword argument 'prefer_fwd_module'
+ "pyrefly>=0.61.1",
]
[build-system]
@@ -62,6 +62,7 @@ module-root = ""
python_version = "3.10"
strict = true
+
[tool.ruff]
fix = true
unsafe-fixes = true
diff --git a/tests/experimental/test_container_2.py b/tests/experimental/test_container_2.py
index f5a9ec46..30b11561 100644
--- a/tests/experimental/test_container_2.py
+++ b/tests/experimental/test_container_2.py
@@ -22,6 +22,12 @@ def __hash__(self) -> int:
return 0 # pragma: nocover
+class _BrokenContextObject(providers.Object[int]):
+ def get_scope(self) -> ContextScopes | None:
+ msg = "_missing_scope"
+ raise AttributeError(msg)
+
+
async def _async_creator() -> AsyncIterator[float]:
yield random.random()
@@ -30,6 +36,9 @@ def _sync_creator() -> Iterator[_RandomWrapper]:
yield _RandomWrapper()
+broken_context_provider = _BrokenContextObject(1)
+
+
class Container2(BaseContainer):
"""Test Container 2."""
@@ -139,6 +148,13 @@ async def test_lazy_provider_not_implemented() -> None:
await lazy_provider.tear_down()
+def test_lazy_provider_not_implemented_when_context_method_raises_attribute_error() -> None:
+ lazy_provider = LazyProvider("tests.experimental.test_container_2.broken_context_provider")
+
+ with pytest.raises(NotImplementedError):
+ lazy_provider.get_scope()
+
+
def test_lazy_provider_attr_getter() -> None:
lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider")
with lazy_provider.context_sync(force=True):
diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py
index b1292c63..c8ea886b 100644
--- a/tests/providers/test_attr_getter.py
+++ b/tests/providers/test_attr_getter.py
@@ -171,6 +171,7 @@ class _Container(BaseContainer):
attr_getter = _Container.child.v
+ attr_getter._register_arguments()
attr_getter._register_arguments()
assert attr_getter.resolve_sync() == _Container.parent.resolve_sync()
diff --git a/tests/providers/test_base.py b/tests/providers/test_base.py
index 1c8edbb2..7a0601cf 100644
--- a/tests/providers/test_base.py
+++ b/tests/providers/test_base.py
@@ -6,11 +6,12 @@
from typing_extensions import override
from that_depends import BaseContainer
-from that_depends.providers.base import AbstractProvider
+from that_depends.providers.base import AbstractProvider, _resolve_arguments, _resolve_arguments_sync
from that_depends.providers.local_singleton import ThreadLocalSingleton
from that_depends.providers.mixin import SupportsTeardown
from that_depends.providers.resources import Resource
from that_depends.providers.singleton import AsyncSingleton, Singleton
+from that_depends.utils import is_set
class DummyProvider(SupportsTeardown, AbstractProvider[int]):
@@ -43,6 +44,18 @@ def resolve_sync(self) -> int:
return self._instance # pragma: no cover
+async def test_resolve_arguments_with_multiple_values() -> None:
+ provider = DummyProvider()
+
+ assert await _resolve_arguments((provider, "value"), (True, False)) == [1, "value"]
+
+
+def test_resolve_arguments_sync_with_multiple_values() -> None:
+ provider = DummyProvider()
+
+ assert _resolve_arguments_sync((provider, "value"), (True, False)) == [1, "value"]
+
+
def test_add_child_provider() -> None:
provider_a = DummyProvider()
provider_b = DummyProvider()
@@ -75,6 +88,28 @@ def test_register_with_mixed_items() -> None:
assert parent in child_1._children, "Expected child_1._children to contain parent"
+def test_invalidate_scope_init_order_handles_duplicate_descendants() -> None:
+ root = DummyProvider()
+ left = DummyProvider()
+ right = DummyProvider()
+ shared = DummyProvider()
+
+ root.add_child_provider(left)
+ root.add_child_provider(right)
+ left.add_child_provider(shared)
+ right.add_child_provider(shared)
+
+ for provider in (root, left, right, shared):
+ provider._scope_context_init_order = ()
+ provider._scope_init_order = ()
+
+ root._invalidate_scope_init_order()
+
+ for provider in (root, left, right, shared):
+ assert provider._scope_context_init_order is None
+ assert provider._scope_init_order is None
+
+
def test_sync_tear_down_propagation() -> None:
parent = DummyProvider()
child_1 = DummyProvider()
@@ -168,6 +203,16 @@ def test_thread_local_singleton_registration_and_deregistration(dummy_singleton:
assert thread_local not in dummy_singleton._children, "ThreadLocalSingleton should be deregistered after teardown."
+def test_get_scope_init_order_is_cached_and_includes_parents(dummy_singleton: Singleton[int]) -> None:
+ singleton = Singleton(lambda value: value + 1, dummy_singleton.cast)
+
+ first = singleton._get_scope_init_order()
+ second = singleton._get_scope_init_order()
+
+ assert first == (dummy_singleton, singleton)
+ assert second is first
+
+
def test_resource_registration_and_deregistration(dummy_singleton: Singleton[int]) -> None:
resource = Resource(_resource_generator, dummy_singleton.cast)
@@ -233,7 +278,7 @@ def test_propagate_off() -> None:
parent.tear_down_sync(propagate=False)
assert child in parent._children
- assert child._instance is not None
+ assert is_set(child._instance)
async def test_async_tear_down_propagation_with_singleton() -> None:
@@ -244,7 +289,7 @@ async def test_async_tear_down_propagation_with_singleton() -> None:
await parent.tear_down()
- assert child._instance is None
+ assert not is_set(child._instance)
async def test_async_propagate_off() -> None:
@@ -255,7 +300,7 @@ async def test_async_propagate_off() -> None:
await parent.tear_down(propagate=False)
- assert child._instance is not None
+ assert is_set(child._instance)
async def test_provider_registration_in_different_scope_async() -> None:
diff --git a/tests/providers/test_collections.py b/tests/providers/test_collections.py
index 14a58ae3..5d028a35 100644
--- a/tests/providers/test_collections.py
+++ b/tests/providers/test_collections.py
@@ -28,7 +28,9 @@ async def test_list_provider() -> None:
sync_resource = await DIContainer.sync_resource()
async_resource = await DIContainer.async_resource()
- assert sequence == [sync_resource, async_resource]
+ assert sequence == (sync_resource, async_resource)
+ with pytest.raises(TypeError):
+ typing.cast(typing.Any, sequence)[0] = sync_resource
def test_list_failed_sync_resolve() -> None:
@@ -48,6 +50,8 @@ async def test_dict_provider() -> None:
assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource}
assert mapping == DIContainer.mapping.resolve_sync()
+ with pytest.raises(TypeError):
+ typing.cast(typing.Any, mapping)["sync_resource"] = sync_resource
@pytest.mark.parametrize("provider", [DIContainer.sequence, DIContainer.mapping])
diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py
index 9420097f..02a39c8c 100644
--- a/tests/providers/test_context_resources.py
+++ b/tests/providers/test_context_resources.py
@@ -24,6 +24,7 @@
from that_depends.meta import DefaultScopeNotDefinedError
from that_depends.providers import DIContextMiddleware, container_context
from that_depends.providers.context_resources import InvalidContextError, _enter_named_scope, fetch_context_item_by_type
+from that_depends.utils import is_set
logger = logging.getLogger(__name__)
@@ -186,14 +187,15 @@ async def test_early_exit_of_container_context() -> None:
async def test_resource_context_early_teardown() -> None:
context: ResourceContext[str] = ResourceContext(is_async=True)
- assert context.context_stack is None
+ assert context.is_async is True
+ assert not is_set(context.context_stack)
context.tear_down_sync()
- assert context.context_stack is None
+ assert not is_set(context.context_stack)
async def test_teardown_sync_container_context_with_async_resource() -> None:
resource_context: ResourceContext[typing.Any] = ResourceContext(is_async=True)
- resource_context.context_stack = AsyncExitStack()
+ resource_context.set_context_state(context_stack=AsyncExitStack())
message = "Cannot tear down async context in sync mode"
with pytest.raises(RuntimeError, match=message):
resource_context.tear_down_sync()
@@ -563,6 +565,56 @@ def _explicit_injected(val: str = Provide[sync_context_resource]) -> str:
assert _explicit_injected() != _explicit_injected()
+def test_sync_context_resource_accepts_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ current_scope = ContextScope("MATCHING_SCOPE")
+ resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope)
+
+ with _enter_named_scope(current_scope), resource.context_sync():
+ assert isinstance(resource.resolve_sync(), str)
+
+
+async def test_async_context_resource_accepts_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ current_scope = ContextScope("MATCHING_SCOPE")
+ resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope)
+
+ with _enter_named_scope(current_scope):
+ async with resource.context_async():
+ assert isinstance(await resource.resolve(), str)
+
+
+async def test_async_context_resource_strict_scope_accepts_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ current_scope = ContextScope("MATCHING_SCOPE")
+ resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope, strict_scope=True)
+
+ with _enter_named_scope(current_scope):
+ async with resource.context_async(force=True):
+ assert isinstance(await resource.resolve(), str)
+
+
+def test_sync_context_resource_strict_scope_accepts_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ current_scope = ContextScope("MATCHING_SCOPE")
+ resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope, strict_scope=True)
+
+ with _enter_named_scope(current_scope), resource.context_sync(force=True):
+ assert isinstance(resource.resolve_sync(), str)
+
+
+def test_container_context_accepts_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ current_scope = ContextScope("MATCHING_SCOPE")
+
+ class _Container(BaseContainer):
+ default_scope = ContextScopes.ANY
+ sync_context_resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope)
+
+ with container_context(_Container, scope=current_scope):
+ assert isinstance(_Container.sync_context_resource.resolve_sync(), str)
+
+
async def test_async_context_resource_with_dependent_container() -> None:
"""Container should initialize async context resource for dependent containers."""
async with DIContainer.context_async():
@@ -590,9 +642,18 @@ def test_enter_sync_context_for_async_resource_should_throw(
async_context_resource._enter_context_sync()
-def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None:
- with pytest.raises(RuntimeError):
- sync_context_resource._exit_context_sync()
+def test_exit_sync_context_before_enter_should_throw() -> None:
+ provider = providers.ContextResource(create_sync_context_resource)
+
+ with pytest.raises(RuntimeError, match=r"Context is not set, call ``_enter_sync_context`` first"):
+ provider._exit_context_sync()
+
+
+def test_context_sync_manager_exit_before_enter_should_throw(
+ sync_context_resource: providers.ContextResource[str],
+) -> None:
+ with pytest.raises(RuntimeError, match=r"Context is not set, call ``__enter__`` first"):
+ sync_context_resource.context_sync().__exit__(None, None, None)
async def test_exit_async_context_before_enter_should_throw(
@@ -609,6 +670,18 @@ def test_enter_sync_context_from_async_resource_should_throw(
stack.enter_context(async_context_resource.context_sync())
+def test_enter_and_exit_sync_context_directly(sync_context_resource: providers.ContextResource[str]) -> None:
+ context = sync_context_resource._enter_context_sync()
+
+ assert isinstance(context, ResourceContext)
+ assert sync_context_resource.resolve_sync() is not None
+
+ sync_context_resource._exit_context_sync()
+
+ with pytest.raises(RuntimeError, match=r"Context is not set. Use container_context"):
+ sync_context_resource.resolve_sync()
+
+
async def test_preserve_globals_and_initial_context() -> None:
initial_context = {"test_1": "test_1", "test_2": "test_2"}
@@ -699,6 +772,11 @@ class _Container(BaseContainer): ...
assert _Container.get_scope() is ContextScopes.ANY
+ class _UnsetScopeContainer(BaseContainer):
+ default_scope = None
+
+ assert _UnsetScopeContainer.get_scope() is ContextScopes.ANY
+
class _ScopedContainer(BaseContainer):
default_scope = ContextScopes.INJECT
diff --git a/tests/providers/test_local_singleton.py b/tests/providers/test_local_singleton.py
index a614c958..4f92ddae 100644
--- a/tests/providers/test_local_singleton.py
+++ b/tests/providers/test_local_singleton.py
@@ -4,10 +4,12 @@
import time
import typing
from concurrent.futures.thread import ThreadPoolExecutor
+from unittest.mock import Mock
import pytest
from that_depends.providers import AsyncFactory, ThreadLocalSingleton
+from that_depends.utils import is_set
random.seed(23)
@@ -34,7 +36,7 @@ def test_thread_local_singleton_same_thread() -> None:
provider.tear_down_sync()
- assert provider._instance is None, "Tear down failed: Instance should be reset to None."
+ assert not is_set(provider._instance), "Tear down failed: Instance should be unset."
async def test_async_thread_local_singleton_asyncio() -> None:
@@ -48,7 +50,23 @@ async def test_async_thread_local_singleton_asyncio() -> None:
await provider.tear_down()
- assert provider._instance is None, "Tear down failed: Instance should be reset to None."
+ assert not is_set(provider._instance), "Tear down failed: Instance should be unset."
+
+
+async def test_thread_local_singleton_reuses_instance_created_while_waiting_on_lock() -> None:
+ expected_value = 42
+ factory = Mock(return_value=1)
+ provider = ThreadLocalSingleton(factory)
+
+ await provider._asyncio_lock.acquire()
+ task = asyncio.create_task(provider.resolve())
+ await asyncio.sleep(0)
+
+ provider._instance = expected_value
+ provider._asyncio_lock.release()
+
+ assert await task == expected_value
+ factory.assert_not_called()
def test_thread_local_singleton_different_threads() -> None:
diff --git a/tests/providers/test_resources.py b/tests/providers/test_resources.py
index a48fc95e..239cd6d3 100644
--- a/tests/providers/test_resources.py
+++ b/tests/providers/test_resources.py
@@ -14,6 +14,7 @@
create_sync_resource,
)
from that_depends import BaseContainer, providers
+from that_depends.utils import is_set
logger = logging.getLogger(__name__)
@@ -189,9 +190,9 @@ def create_resource() -> typing.Iterator[str]:
def test_sync_resource_sync_tear_down() -> None:
DIContainer.sync_resource.resolve_sync()
- assert DIContainer.sync_resource._context.instance is not None
+ assert is_set(DIContainer.sync_resource._context.instance)
DIContainer.sync_resource.tear_down_sync()
- assert DIContainer.sync_resource._context.instance is None
+ assert not is_set(DIContainer.sync_resource._context.instance)
async def test_async_resource_sync_tear_down_raises() -> None:
diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py
index 7bba7bd0..3be5b911 100644
--- a/tests/providers/test_singleton.py
+++ b/tests/providers/test_singleton.py
@@ -9,6 +9,7 @@
import pytest
from that_depends import BaseContainer, providers
+from that_depends.utils import is_set
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
@@ -144,6 +145,16 @@ async def test_async_singleton_override() -> None:
assert result == SingletonFactory(dep1="bar")
+def test_async_singleton_register_arguments_is_idempotent() -> None:
+ parent = providers.Singleton(lambda: "foo")
+ singleton_async = providers.AsyncSingleton(create_async_obj, value=parent.cast)
+
+ singleton_async._register_arguments()
+ singleton_async._register_arguments()
+
+ assert singleton_async in parent._children
+
+
async def test_async_singleton_asyncio_concurrency() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
@@ -187,9 +198,9 @@ async def test_async_singleton_teardown() -> None:
await singleton_async.resolve()
singleton_async.tear_down_sync()
- assert singleton_async._instance is None
+ assert not is_set(singleton_async._instance)
await singleton_async.resolve()
await singleton_async.tear_down()
- assert singleton_async._instance is None
+ assert not is_set(singleton_async._instance)
diff --git a/tests/test_container.py b/tests/test_container.py
index f7f4752d..fcc5ac09 100644
--- a/tests/test_container.py
+++ b/tests/test_container.py
@@ -3,6 +3,7 @@
from tests.container import DIContainer
from that_depends import BaseContainer, providers
+from that_depends.utils import is_set
def _sync_resource() -> typing.Iterator[float]:
@@ -15,11 +16,11 @@ async def test_container_sync_teardown() -> None:
for provider in DIContainer.providers.values():
if isinstance(provider, providers.Resource):
if provider._is_async:
- assert provider._context.instance is not None
+ assert is_set(provider._context.instance)
else:
- assert provider._context.instance is None
+ assert not is_set(provider._context.instance)
if isinstance(provider, providers.Singleton):
- assert provider._instance is None
+ assert not is_set(provider._instance)
async def test_container_tear_down() -> None:
@@ -27,9 +28,9 @@ async def test_container_tear_down() -> None:
await DIContainer.tear_down()
for provider in DIContainer.providers.values():
if isinstance(provider, providers.Resource):
- assert provider._context.instance is None
+ assert not is_set(provider._context.instance)
if isinstance(provider, providers.Singleton):
- assert provider._instance is None
+ assert not is_set(provider._instance)
async def test_container_sync_tear_down_propagation() -> None:
@@ -46,8 +47,8 @@ class _DependentContainer(BaseContainer):
DIContainer.tear_down_sync()
- assert _DependentContainer.singleton._instance is None
- assert _DependentContainer.resource._context.instance is None
+ assert not is_set(_DependentContainer.singleton._instance)
+ assert not is_set(_DependentContainer.resource._context.instance)
async def test_container_tear_down_propagation() -> None:
@@ -74,7 +75,7 @@ class _DependentContainer(BaseContainer):
await DIContainer.tear_down()
- assert _DependentContainer.async_singleton._instance is None
- assert _DependentContainer.async_resource._context.instance is None
- assert _DependentContainer.sync_singleton._instance is None
- assert _DependentContainer.sync_resource._context.instance is None
+ assert not is_set(_DependentContainer.async_singleton._instance)
+ assert not is_set(_DependentContainer.async_resource._context.instance)
+ assert not is_set(_DependentContainer.sync_singleton._instance)
+ assert not is_set(_DependentContainer.sync_resource._context.instance)
diff --git a/tests/test_injection.py b/tests/test_injection.py
index 448f4ecf..ac4f3d5b 100644
--- a/tests/test_injection.py
+++ b/tests/test_injection.py
@@ -9,8 +9,24 @@
import pytest
from tests import container
-from that_depends import BaseContainer, ContextScopes, Provide, container_context, get_current_scope, inject, providers
-from that_depends.injection import ContextProviderError, StringProviderDefinition
+from that_depends import (
+ BaseContainer,
+ ContextScope,
+ ContextScopes,
+ Provide,
+ container_context,
+ get_current_scope,
+ inject,
+ providers,
+)
+from that_depends.injection import (
+ ContextProviderError,
+ StringProviderDefinition,
+ _build_injection_plan,
+ _DirectInjectionParameter,
+ _SyncInjectionStack,
+ _TypedInjectionParameter,
+)
@pytest.fixture(name="fixture_one")
@@ -55,6 +71,61 @@ async def inner(
await inner(True, arg2=container.SimpleFactory(dep1="1", dep2=2))
+def test_sync_injection_stack_closes_entered_context_managers() -> None:
+ events: list[str] = []
+
+ @contextmanager
+ def _managed() -> typing.Iterator[str]:
+ events.append("enter")
+ try:
+ yield "value"
+ finally:
+ events.append("exit")
+
+ with _SyncInjectionStack() as stack:
+ assert stack.enter_context(_managed()) == "value"
+ events.append("body")
+
+ assert events == ["enter", "body", "exit"]
+
+
+def test_build_injection_plan_stores_direct_provider_separately() -> None:
+ provider = providers.Object(1)
+
+ def _injected(value: providers.Object[int] = provider) -> providers.Object[int]:
+ return value
+
+ plan = _build_injection_plan(_injected)
+
+ assert _injected(provider) is provider
+ assert plan.direct_parameters == (
+ _DirectInjectionParameter(0, "value", provider, provider._get_scope_context_init_order()),
+ )
+
+
+def test_build_injection_plan_stores_annotation_for_type_based_injection() -> None:
+ def _injected(value: float = Provide()) -> float:
+ return value
+
+ plan = _build_injection_plan(_injected)
+
+ assert _injected(1.0) == 1.0
+ assert plan.typed_parameters == (_TypedInjectionParameter(0, "value", float),)
+
+
+def test_build_injection_plan_stores_string_provider_separately() -> None:
+ def _injected(value: int = Provide["Container.provider"]) -> int:
+ return value
+
+ plan = _build_injection_plan(_injected)
+
+ assert _injected(1) == 1
+ assert len(plan.string_parameters) == 1
+ assert plan.string_parameters[0].argument_index == 0
+ assert plan.string_parameters[0].field_name == "value"
+ assert plan.string_parameters[0].definition._definition == "Container.provider"
+
+
async def test_empty_injection() -> None:
@inject
async def inner(_: int) -> None:
@@ -116,6 +187,15 @@ def inner_gen(_: int) -> typing.Generator[int, None, None]:
next(inner_gen(1))
+def test_sync_empty_injection_without_scope_warns() -> None:
+ @inject(scope=None)
+ def inner(value: int) -> int:
+ return value
+
+ with pytest.warns(RuntimeWarning, match=r"Expected injection, but nothing found. Remove @inject decorator."):
+ assert inner(1) == 1
+
+
def test_type_check() -> None:
@inject
async def main(simple_factory: container.SimpleFactory = Provide[container.DIContainer.simple_factory]) -> None:
@@ -140,6 +220,20 @@ async def _injected(val: int = Provide[_Container.async_resource]) -> int:
await inject(scope=ContextScopes.REQUEST)(_injected)()
+async def test_async_injection_with_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ injection_scope = ContextScope("MATCHING_SCOPE")
+
+ class _Container(BaseContainer):
+ default_scope = ContextScopes.ANY
+ async_resource = providers.ContextResource(_async_creator).with_config(scope=configured_scope)
+
+ async def _injected(val: int = Provide[_Container.async_resource]) -> int:
+ return val
+
+ assert await inject(scope=injection_scope)(_injected)() == 1
+
+
async def test_sync_injection_with_scope() -> None:
class _Container(BaseContainer):
default_scope = ContextScopes.ANY
@@ -156,6 +250,20 @@ def _injected(val: int = Provide[_Container.p_inject]) -> int:
inject(scope=ContextScopes.REQUEST)(_injected)()
+def test_sync_injection_with_equal_scope_instance() -> None:
+ configured_scope = ContextScope("MATCHING_SCOPE")
+ injection_scope = ContextScope("MATCHING_SCOPE")
+
+ class _Container(BaseContainer):
+ default_scope = ContextScopes.ANY
+ p_inject = providers.ContextResource(_sync_creator).with_config(scope=configured_scope)
+
+ def _injected(val: int = Provide[_Container.p_inject]) -> int:
+ return val
+
+ assert inject(scope=injection_scope)(_injected)() == 1
+
+
def test_inject_decorator_should_not_allow_any_scope() -> None:
with pytest.raises(ValueError, match=f"{ContextScopes.ANY} is not allowed in inject decorator."):
inject(scope=ContextScopes.ANY)
@@ -220,6 +328,34 @@ def _injected(val: int = Provide["_Container.sync_resource"]) -> int:
assert _injected() == return_value
+async def test_async_injection_with_string_provider_definition_respects_explicit_arguments() -> None:
+ override_value = 10
+
+ class _Container(BaseContainer):
+ async_resource = providers.AsyncFactory(_async_creator)
+
+ @inject
+ async def _injected(val: int = Provide["_Container.async_resource"]) -> int:
+ return val
+
+ assert await _injected(override_value) == override_value
+ assert await _injected(val=override_value) == override_value
+
+
+def test_sync_injection_with_string_provider_definition_respects_explicit_arguments() -> None:
+ override_value = 10
+
+ class _Container(BaseContainer):
+ sync_resource = providers.Factory(lambda: 1)
+
+ @inject
+ def _injected(val: int = Provide["_Container.sync_resource"]) -> int:
+ return val
+
+ assert _injected(override_value) == override_value
+ assert _injected(val=override_value) == override_value
+
+
def test_provider_string_definition_with_alias() -> None:
return_value = 321
@@ -336,6 +472,46 @@ def _injected(v: float = Provide[_Container.provider_used]) -> float:
assert _Container.provider_used.resolve_sync()
+async def test_injection_with_string_provider_definition_initializes_parent_context_async() -> None:
+ async def _async_resource() -> typing.AsyncIterator[float]:
+ yield random.random()
+
+ class _Container(BaseContainer):
+ provider_used = providers.ContextResource(_async_resource).with_config(scope=ContextScopes.INJECT)
+ dependent = providers.Factory(lambda value: value, provider_used.cast)
+
+ @inject
+ async def _injected(val: float = Provide["_Container.dependent"]) -> float:
+ assert val == await _Container.dependent.resolve()
+ assert val == await _Container.provider_used.resolve()
+ return val
+
+ assert isinstance(await _injected(), float)
+
+ with pytest.raises(RuntimeError):
+ await _Container.provider_used.resolve()
+
+
+def test_injection_with_string_provider_definition_initializes_parent_context_sync() -> None:
+ def _sync_resource() -> typing.Iterator[float]:
+ yield random.random()
+
+ class _Container(BaseContainer):
+ provider_used = providers.ContextResource(_sync_resource).with_config(scope=ContextScopes.INJECT)
+ dependent = providers.Factory(lambda value: value, provider_used.cast)
+
+ @inject
+ def _injected(val: float = Provide["_Container.dependent"]) -> float:
+ assert val == _Container.dependent.resolve_sync()
+ assert val == _Container.provider_used.resolve_sync()
+ return val
+
+ assert isinstance(_injected(), float)
+
+ with pytest.raises(RuntimeError):
+ _Container.provider_used.resolve_sync()
+
+
async def test_injection_initializes_context_for_parents_async() -> None:
async def _async_resource() -> typing.AsyncIterator[float]:
yield random.random()
@@ -863,6 +1039,20 @@ def _injected_2(val: float = Provide()) -> float:
assert isinstance(_injected_2(), float)
+def test_injection_by_type_sync_respects_explicit_arguments() -> None:
+ class _Container(BaseContainer):
+ sync_resource = providers.Factory(random.random).bind(float)
+
+ override_value = 10.0
+
+ @inject(container=_Container)
+ def _injected(val: float = Provide()) -> float:
+ return val
+
+ assert _injected(override_value) == override_value
+ assert _injected(val=override_value) == override_value
+
+
async def test_injection_by_type_async() -> None:
class _Container(BaseContainer):
sync_resource = providers.Factory(random.random).bind(float)
@@ -879,6 +1069,20 @@ async def _injected_2(val: float = Provide()) -> float:
assert isinstance(await _injected_2(), float)
+async def test_injection_by_type_async_respects_explicit_arguments() -> None:
+ class _Container(BaseContainer):
+ sync_resource = providers.Factory(random.random).bind(float)
+
+ override_value = 10.0
+
+ @inject(container=_Container)
+ async def _injected(val: float = Provide()) -> float:
+ return val
+
+ assert await _injected(override_value) == override_value
+ assert await _injected(val=override_value) == override_value
+
+
async def test_injection_by_type_async_generator() -> None:
class _Container(BaseContainer):
sync_resource = providers.Factory(random.random).bind(float)
diff --git a/tests/test_meta.py b/tests/test_meta.py
index b4efc149..6f1a9353 100644
--- a/tests/test_meta.py
+++ b/tests/test_meta.py
@@ -9,6 +9,13 @@ class _Test(metaclass=BaseContainerMeta):
assert _Test.get_scope() == ContextScopes.ANY
+def test_base_container_meta_uses_explicit_default_scope() -> None:
+ class _Test(metaclass=BaseContainerMeta):
+ default_scope = ContextScopes.INJECT
+
+ assert _Test.get_scope() == ContextScopes.INJECT
+
+
def test_base_container_meta_has_correct_default_name() -> None:
class _Test(metaclass=BaseContainerMeta):
pass
diff --git a/tests/test_multiple_containers.py b/tests/test_multiple_containers.py
index 6e74282a..69a2246f 100644
--- a/tests/test_multiple_containers.py
+++ b/tests/test_multiple_containers.py
@@ -4,6 +4,7 @@
from tests import container
from that_depends import BaseContainer, providers
+from that_depends.utils import is_set
class InnerContainer(BaseContainer):
@@ -23,16 +24,16 @@ async def test_included_container() -> None:
assert all(isinstance(x, datetime.datetime) for x in sequence)
await OuterContainer.tear_down()
- assert InnerContainer.sync_resource._context.instance is None
- assert InnerContainer.async_resource._context.instance is None
+ assert not is_set(InnerContainer.sync_resource._context.instance)
+ assert not is_set(InnerContainer.async_resource._context.instance)
await OuterContainer.init_resources()
sync_resource_context = InnerContainer.sync_resource._context
assert sync_resource_context
- assert sync_resource_context.instance is not None
+ assert is_set(sync_resource_context.instance)
async_resource_context = InnerContainer.async_resource._context
assert async_resource_context
- assert async_resource_context.instance is not None
+ assert is_set(async_resource_context.instance)
await OuterContainer.tear_down()
diff --git a/that_depends/container.py b/that_depends/container.py
index f42d08a2..d23afe40 100644
--- a/that_depends/container.py
+++ b/that_depends/container.py
@@ -1,13 +1,13 @@
import inspect
import typing
-from contextlib import contextmanager
+from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from typing import overload
from typing_extensions import override
from that_depends.meta import BaseContainerMeta
from that_depends.providers import AbstractProvider, AsyncSingleton, Resource, Singleton
-from that_depends.providers.context_resources import ContextScope, ContextScopes
+from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes
if typing.TYPE_CHECKING:
@@ -26,6 +26,42 @@ class BaseContainer(metaclass=BaseContainerMeta):
containers: list[type["BaseContainer"]]
default_scope: ContextScope | None = ContextScopes.ANY
+ @classmethod
+ def get_scope(cls) -> ContextScope | None:
+ """Return the default scope used by the container."""
+ if cls.default_scope is not None:
+ return cls.default_scope
+ return ContextScopes.ANY
+
+ @classmethod
+ @asynccontextmanager
+ async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]:
+ """Enter async contexts for all connected containers and context resources."""
+ async with AsyncExitStack() as stack:
+ for container in cls.get_containers():
+ await stack.enter_async_context(container.context_async(force=force))
+ for provider in cls.get_providers().values():
+ if isinstance(provider, ContextResource):
+ await stack.enter_async_context(provider.context_async(force=force))
+ yield
+
+ @classmethod
+ @contextmanager
+ def context_sync(cls, force: bool = False) -> typing.Iterator[None]:
+ """Enter sync contexts for all connected containers and sync context resources."""
+ with ExitStack() as stack:
+ for container in cls.get_containers():
+ stack.enter_context(container.context_sync(force=force))
+ for provider in cls.get_providers().values():
+ if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001
+ stack.enter_context(provider.context_sync(force=force))
+ yield
+
+ @classmethod
+ def supports_context_sync(cls) -> bool:
+ """Indicate that container classes support sync context management."""
+ return True
+
@classmethod
@overload
def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ...
diff --git a/that_depends/entities/resource_context.py b/that_depends/entities/resource_context.py
index 8e10b6e4..725c0a94 100644
--- a/that_depends/entities/resource_context.py
+++ b/that_depends/entities/resource_context.py
@@ -7,6 +7,7 @@
from typing_extensions import override
from that_depends.providers.mixin import CannotTearDownSyncError, SupportsTeardown
+from that_depends.utils import UNSET, Unset, is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -26,22 +27,41 @@ def __init__(self, is_async: bool) -> None:
For example within a ``async with container_context(Container): ...`` statement.
"""
- self.instance: T_co | None = None
+ self.instance: T_co | Unset = UNSET
self.asyncio_lock: typing.Final = asyncio.Lock()
self.threading_lock: typing.Final = threading.Lock()
- self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None
+ self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET
self.is_async = is_async
+ def set_context_state(
+ self,
+ *,
+ instance: T_co | Unset = UNSET,
+ context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET,
+ ) -> None:
+ """Set the context state of the resource.
+
+ Args:
+ instance: instance.
+ context_stack: stack.
+
+ Returns:
+ None
+
+ """
+ self.instance = instance
+ self.context_stack = context_stack
+
@staticmethod
def is_context_stack_async(
- context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None,
+ context_stack: object,
) -> typing.TypeGuard[contextlib.AsyncExitStack]:
"""Check if the context stack is an async context stack."""
return isinstance(context_stack, contextlib.AsyncExitStack)
@staticmethod
def is_context_stack_sync(
- context_stack: contextlib.AsyncExitStack | contextlib.ExitStack,
+ context_stack: object,
) -> typing.TypeGuard[contextlib.ExitStack]:
"""Check if the context stack is a sync context stack."""
return isinstance(context_stack, contextlib.ExitStack)
@@ -49,27 +69,27 @@ def is_context_stack_sync(
@override
async def tear_down(self, propagate: bool = True) -> None:
"""Tear down the async context stack."""
- if self.context_stack is None:
+ context_stack = self.context_stack
+ if not is_set(context_stack):
return
- if self.is_context_stack_async(self.context_stack):
- await self.context_stack.aclose()
- elif self.is_context_stack_sync(self.context_stack):
- self.context_stack.close()
- self.context_stack = None
- self.instance = None
+ if self.is_context_stack_async(context_stack):
+ await context_stack.aclose()
+ elif self.is_context_stack_sync(context_stack):
+ context_stack.close()
+ self.set_context_state(instance=UNSET, context_stack=UNSET)
@override
def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
"""Tear down the sync context stack."""
- if self.context_stack is None:
+ context_stack = self.context_stack
+ if not is_set(context_stack):
return
- if self.is_context_stack_sync(self.context_stack):
- self.context_stack.close()
- self.context_stack = None
- self.instance = None
- elif self.is_context_stack_async(self.context_stack):
+ if self.is_context_stack_sync(context_stack):
+ context_stack.close()
+ self.set_context_state(instance=UNSET, context_stack=UNSET)
+ elif self.is_context_stack_async(context_stack):
msg = "Cannot tear down async context in sync mode"
if raise_on_async:
raise CannotTearDownSyncError(msg)
diff --git a/that_depends/experimental/providers.py b/that_depends/experimental/providers.py
index c79becf6..6a5b1ec5 100644
--- a/that_depends/experimental/providers.py
+++ b/that_depends/experimental/providers.py
@@ -7,7 +7,7 @@
from that_depends import ContextScope
from that_depends.providers import AbstractProvider
-from that_depends.providers.context_resources import CT, SupportsContext
+from that_depends.providers.context_resources import CT_co, SupportsContext
from that_depends.providers.mixin import SupportsTeardown
@@ -55,32 +55,19 @@ def __init__(
@typing_extensions.override
def get_scope(self) -> ContextScope | None:
- provider = self._get_provider()
- if isinstance(provider, SupportsContext):
- return provider.get_scope()
- msg = "Underlying provider does not support context scopes"
- raise NotImplementedError(msg)
+ return typing.cast(ContextScope | None, self._call_context_method("get_scope"))
@typing_extensions.override
- def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]:
- provider = self._get_provider()
- if isinstance(provider, SupportsContext):
- return provider.context_async(force)
- msg = "Underlying provider does not support context management"
- raise NotImplementedError(msg)
+ def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]:
+ return typing.cast(typing.AsyncContextManager[CT_co], self._call_context_method("context_async", force))
@typing_extensions.override
- def context_sync(self, force: bool = False) -> typing.ContextManager[CT]:
- provider = self._get_provider()
- if isinstance(provider, SupportsContext):
- return provider.context_sync(force)
- msg = "Underlying provider does not support context management"
- raise NotImplementedError(msg)
+ def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]:
+ return typing.cast(typing.ContextManager[CT_co], self._call_context_method("context_sync", force))
@typing_extensions.override
def supports_context_sync(self) -> bool:
- provider = self._get_provider()
- return isinstance(provider, SupportsContext) and provider.supports_context_sync()
+ return typing.cast(bool, self._call_context_method("supports_context_sync"))
@typing_extensions.override
async def tear_down(self, propagate: bool = True) -> None:
@@ -139,6 +126,19 @@ def _get_provider(self) -> AbstractProvider[Any]:
self._provider = cast(AbstractProvider[Any], provider)
return self._provider
+ def _call_context_method(self, method_name: str, *args: object) -> object:
+ provider = self._get_provider()
+ try:
+ method = getattr(type(provider), method_name)
+ except AttributeError as e:
+ msg = "Underlying provider does not support context management"
+ raise NotImplementedError(msg) from e
+ try:
+ return method(provider, *args)
+ except AttributeError as e:
+ msg = "Underlying provider does not support context management"
+ raise NotImplementedError(msg) from e
+
@typing_extensions.override
async def resolve(self) -> Any:
provider = self._get_provider()
diff --git a/that_depends/injection.py b/that_depends/injection.py
index 3a48dba7..bd020b73 100644
--- a/that_depends/injection.py
+++ b/that_depends/injection.py
@@ -1,19 +1,18 @@
-import asyncio
-import contextlib
import functools
import inspect
import re
-import threading
import typing
import warnings
-from contextlib import AsyncExitStack, ExitStack
+from contextlib import AsyncExitStack
+from types import TracebackType
+
+from typing_extensions import Self
from that_depends.container import BaseContainer
from that_depends.exceptions import TypeNotBoundError
from that_depends.meta import BaseContainerMeta
-from that_depends.providers import AbstractProvider, ContextResource
+from that_depends.providers import AbstractProvider
from that_depends.providers.context_resources import ContextScope, ContextScopes, container_context
-from that_depends.providers.mixin import ProviderWithArguments
class ContextProviderError(Exception):
@@ -29,6 +28,107 @@ class ContextProviderError(Exception):
)
+class _DirectInjectionParameter(typing.NamedTuple):
+ argument_index: int
+ field_name: str
+ provider: AbstractProvider[typing.Any]
+ scope_context_init_order: tuple[AbstractProvider[typing.Any], ...]
+
+
+class _StringInjectionParameter(typing.NamedTuple):
+ argument_index: int
+ field_name: str
+ definition: "StringProviderDefinition"
+
+
+class _TypedInjectionParameter(typing.NamedTuple):
+ argument_index: int
+ field_name: str
+ annotation: type[typing.Any]
+
+
+class _InjectionPlan(typing.NamedTuple):
+ direct_parameters: tuple[_DirectInjectionParameter, ...]
+ string_parameters: tuple[_StringInjectionParameter, ...]
+ typed_parameters: tuple[_TypedInjectionParameter, ...]
+
+
+class _SyncInjectionStack:
+ __slots__ = ("_exit_states",)
+
+ def __init__(self) -> None:
+ self._exit_states: list[_SupportsClose] = []
+
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> typing.Literal[False]:
+ _ = exc_type, exc_value, traceback
+ while self._exit_states:
+ self._exit_states.pop().close()
+ return False
+
+ def enter_context(self, context_manager: typing.ContextManager[T]) -> T:
+ value = context_manager.__enter__()
+ self._exit_states.append(_ContextManagerExitState(context_manager))
+ return value
+
+ def push_exit_state(self, exit_state: "_SupportsClose") -> None:
+ self._exit_states.append(exit_state)
+
+
+class _SupportsClose(typing.Protocol):
+ def close(self) -> None: ...
+
+
+class _ContextManagerExitState:
+ __slots__ = ("_context_manager",)
+
+ def __init__(self, context_manager: typing.ContextManager[typing.Any]) -> None:
+ self._context_manager = context_manager
+
+ def close(self) -> None:
+ self._context_manager.__exit__(None, None, None)
+
+
+@functools.cache
+def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionPlan:
+ direct_parameters: list[_DirectInjectionParameter] = []
+ string_parameters: list[_StringInjectionParameter] = []
+ typed_parameters: list[_TypedInjectionParameter] = []
+ for index, (field_name, param) in enumerate(inspect.signature(func).parameters.items()):
+ default = param.default
+ if isinstance(default, StringProviderDefinition):
+ string_parameters.append(_StringInjectionParameter(index, field_name, default))
+ elif isinstance(default, AbstractProvider):
+ direct_parameters.append(
+ _DirectInjectionParameter(
+ index,
+ field_name,
+ default,
+ default._get_scope_context_init_order(), # noqa: SLF001
+ )
+ )
+ elif isinstance(default, _Provide):
+ typed_parameters.append(
+ _TypedInjectionParameter(
+ index,
+ field_name,
+ typing.cast(type[typing.Any], param.annotation),
+ )
+ )
+ return _InjectionPlan(
+ direct_parameters=tuple(direct_parameters),
+ string_parameters=tuple(string_parameters),
+ typed_parameters=tuple(typed_parameters),
+ )
+
+
@typing.overload
def inject(func: typing.Callable[P, T]) -> typing.Callable[P, T]: ...
@@ -88,10 +188,11 @@ def _inject(
def _inject_to_sync_gen(
gen: typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]],
) -> typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]]:
+ plan = _build_injection_plan(gen)
+
@functools.wraps(gen)
def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any, typing.Any]:
- signature = inspect.signature(gen)
- injected, kwargs = _resolve_arguments_sync(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment]
+ injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment]
if not injected:
warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=2)
@@ -105,11 +206,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any,
def _inject_to_async_gen(
gen: typing.Callable[P, typing.AsyncGenerator[T, typing.Any]],
) -> typing.Callable[P, typing.AsyncGenerator[T, typing.Any]]:
+ plan = _build_injection_plan(gen)
+
@functools.wraps(gen)
async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, typing.Any]:
- signature = inspect.signature(gen)
-
- injected, kwargs = await _resolve_arguments_async(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment]
+ injected, kwargs = await _resolve_arguments_async(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment]
if not injected:
warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1)
@@ -122,26 +223,30 @@ async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, t
def _inject_to_async(
func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]],
) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]:
+ plan = _build_injection_plan(func)
+
@functools.wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if enter_scope:
async with container_context(scope=scope):
- return await _resolve_async(func, None, container, *args, **kwargs)
+ return await _resolve_async(func, plan, None, container, *args, **kwargs)
else:
- return await _resolve_async(func, scope, container, *args, **kwargs)
+ return await _resolve_async(func, plan, scope, container, *args, **kwargs)
return inner
def _inject_to_sync(
func: typing.Callable[P, T],
) -> typing.Callable[P, T]:
+ plan = _build_injection_plan(func)
+
@functools.wraps(func)
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if enter_scope:
with container_context(scope=scope):
- return _resolve_sync(func, None, container, *args, **kwargs)
+ return _resolve_sync(func, plan, None, container, *args, **kwargs)
else:
- return _resolve_sync(func, scope, container, *args, **kwargs)
+ return _resolve_sync(func, plan, scope, container, *args, **kwargs)
return inner
@@ -150,131 +255,169 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T:
return _inject(func)
-_SYNC_SIGNATURE_CACHE: dict[typing.Callable[..., typing.Any], inspect.Signature] = {}
-_THREADING_LOCK = threading.Lock()
-
-
async def _resolve_arguments_async(
- signature: inspect.Signature,
+ plan: _InjectionPlan,
scope: ContextScope | None,
container: BaseContainerMeta | None,
stack: AsyncExitStack | None,
*args: typing.Any, # noqa: ANN401
**kwargs: typing.Any, # noqa: ANN401
) -> tuple[bool, dict[str, typing.Any]]:
- injected = False
- context_providers: set[AbstractProvider[typing.Any]] = set()
- params = list(signature.parameters.items())
+ if not _plan_has_injected_parameters(plan):
+ return False, kwargs
- for i, (field_name, param) in enumerate(params):
- default = param.default
+ context_providers: set[AbstractProvider[typing.Any]] = set()
+ for direct_parameter in plan.direct_parameters:
+ if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs):
+ continue
- if i < len(args) or field_name in kwargs:
- if isinstance(default, (AbstractProvider, StringProviderDefinition)):
- injected = True
+ if direct_parameter.scope_context_init_order:
+ await _setup_scope_contexts_async(
+ direct_parameter.scope_context_init_order,
+ scope,
+ stack,
+ context_providers,
+ )
+ kwargs[direct_parameter.field_name] = await direct_parameter.provider.resolve()
+
+ for string_parameter in plan.string_parameters:
+ if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs):
continue
- if isinstance(default, StringProviderDefinition):
- injected = True
- resolved_val = await _resolve_provider_with_scope_async(default.provider, scope, stack, context_providers)
- kwargs[field_name] = resolved_val
- elif isinstance(default, AbstractProvider):
- injected = True
- resolved_val = await _resolve_provider_with_scope_async(default, scope, stack, context_providers)
- kwargs[field_name] = resolved_val
+ kwargs[string_parameter.field_name] = await _resolve_provider_with_scope_async(
+ string_parameter.definition.provider,
+ scope,
+ stack,
+ context_providers,
+ )
- elif isinstance(default, _Provide):
- injected = True
- if container is None:
- raise RuntimeError(_PROVIDE_MESSAGE)
- try:
- provider = container.get_provider_for_type(signature.parameters[field_name].annotation)
- except TypeNotBoundError as e:
- msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider."
- raise RuntimeError(msg) from e
- kwargs[field_name] = await _resolve_provider_with_scope_async(provider, scope, stack, context_providers)
- return injected, kwargs
+ for typed_parameter in plan.typed_parameters:
+ if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs):
+ continue
+
+ provider = _resolve_typed_provider(typed_parameter.annotation, container)
+ kwargs[typed_parameter.field_name] = await _resolve_provider_with_scope_async(
+ provider,
+ scope,
+ stack,
+ context_providers,
+ )
+ return True, kwargs
def _resolve_arguments_sync(
- signature: inspect.Signature,
+ plan: _InjectionPlan,
scope: ContextScope | None,
container: BaseContainerMeta | None,
- stack: contextlib.ExitStack | None,
+ stack: _SyncInjectionStack | None,
*args: typing.Any, # noqa: ANN401
**kwargs: typing.Any, # noqa: ANN401
) -> tuple[bool, dict[str, typing.Any]]:
- injected = False
+ if not _plan_has_injected_parameters(plan):
+ return False, kwargs
+
context_providers: set[AbstractProvider[typing.Any]] = set()
- for i, (field_name, param) in enumerate(signature.parameters.items()):
- default = param.default
- if i < len(args) or field_name in kwargs:
- if isinstance(default, (AbstractProvider, StringProviderDefinition, _Provide)):
- injected = True
+ for direct_parameter in plan.direct_parameters:
+ if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs):
continue
- if isinstance(default, StringProviderDefinition):
- injected = True
- kwargs[field_name] = _resolve_provider_with_scope_sync(default.provider, scope, stack, context_providers)
- elif isinstance(default, AbstractProvider):
- injected = True
- kwargs[field_name] = _resolve_provider_with_scope_sync(default, scope, stack, context_providers)
- elif isinstance(default, _Provide):
- injected = True
- if container is None:
- raise RuntimeError(_PROVIDE_MESSAGE)
- try:
- provider = container.get_provider_for_type(signature.parameters[field_name].annotation)
- except TypeNotBoundError as e:
- msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider."
- raise RuntimeError(msg) from e
- kwargs[field_name] = _resolve_provider_with_scope_sync(provider, scope, stack, context_providers)
+ if direct_parameter.scope_context_init_order:
+ _setup_scope_contexts_sync(
+ direct_parameter.scope_context_init_order,
+ scope,
+ stack,
+ context_providers,
+ )
+ kwargs[direct_parameter.field_name] = direct_parameter.provider.resolve_sync()
+
+ for string_parameter in plan.string_parameters:
+ if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs):
+ continue
+
+ kwargs[string_parameter.field_name] = _resolve_provider_with_scope_sync(
+ string_parameter.definition.provider,
+ scope,
+ stack,
+ context_providers,
+ )
+
+ for typed_parameter in plan.typed_parameters:
+ if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs):
+ continue
+
+ provider = _resolve_typed_provider(typed_parameter.annotation, container)
+ kwargs[typed_parameter.field_name] = _resolve_provider_with_scope_sync(
+ provider,
+ scope,
+ stack,
+ context_providers,
+ )
+
+ return True, kwargs
+
+
+def _plan_has_injected_parameters(plan: _InjectionPlan) -> bool:
+ return bool(plan.direct_parameters or plan.string_parameters or plan.typed_parameters)
+
+
+def _is_argument_provided(
+ argument_index: int,
+ field_name: str,
+ args: tuple[typing.Any, ...],
+ kwargs: dict[str, typing.Any],
+) -> bool:
+ return argument_index < len(args) or field_name in kwargs
+
- return injected, kwargs
+def _resolve_typed_provider(
+ annotation: type[typing.Any],
+ container: BaseContainerMeta | None,
+) -> AbstractProvider[typing.Any]:
+ if container is None:
+ raise RuntimeError(_PROVIDE_MESSAGE)
+ try:
+ return container.get_provider_for_type(annotation)
+ except TypeNotBoundError as e:
+ msg = f"Type {annotation} is not bound to a provider."
+ raise RuntimeError(msg) from e
def _resolve_sync(
func: typing.Callable[P, T],
+ plan: _InjectionPlan,
scope: ContextScope | None,
container: BaseContainerMeta | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
- if func not in _SYNC_SIGNATURE_CACHE:
- with _THREADING_LOCK:
- _SYNC_SIGNATURE_CACHE[func] = inspect.signature(func)
- signature = _SYNC_SIGNATURE_CACHE[func]
-
- with ExitStack() as stack:
- injected, kwargs = _resolve_arguments_sync(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment]
-
+ if scope is None:
+ injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment]
if not injected:
warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3)
return func(*args, **kwargs)
+ with _SyncInjectionStack() as stack:
+ injected, kwargs = _resolve_arguments_sync(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment]
-_SIGNATURE_CACHE: dict[
- typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, typing.Any]], inspect.Signature
-] = {}
+ if not injected:
+ warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3)
-_ASYNCIO_LOCK = asyncio.Lock()
+ return func(*args, **kwargs)
+
+ raise RuntimeError # pragma: no cover # to prevent mypy issue
async def _resolve_async(
func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]],
+ plan: _InjectionPlan,
scope: ContextScope | None,
container: BaseContainerMeta | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
- if func not in _SIGNATURE_CACHE:
- async with _ASYNCIO_LOCK:
- _SIGNATURE_CACHE[func] = inspect.signature(func)
- signature = _SIGNATURE_CACHE[func]
-
async with AsyncExitStack() as stack:
- injected, kwargs = await _resolve_arguments_async(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment]
+ injected, kwargs = await _resolve_arguments_async(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment]
if not injected:
warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1)
@@ -306,30 +449,25 @@ async def _resolve_provider_with_scope_async(
ContextProviderError: if the stack is None.
"""
- await _add_provider_to_stack_async(provider, stack, scope, providers)
+ scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001
+ if scope_context_init_order:
+ await _setup_scope_contexts_async(scope_context_init_order, scope, stack, providers)
return await provider.resolve()
-async def _add_provider_to_stack_async(
- provider: AbstractProvider[T],
- stack: AsyncExitStack | None,
+async def _setup_scope_contexts_async(
+ scope_init_order: tuple[AbstractProvider[typing.Any], ...],
scope: ContextScope | None,
+ stack: AsyncExitStack | None,
providers: set[AbstractProvider[typing.Any]],
) -> None:
- if provider in providers:
- return
- providers.add(provider)
-
if not scope:
return
- if isinstance(provider, ProviderWithArguments):
- provider._register_arguments() # noqa: SLF001
-
- parents = provider._parents # noqa: SLF001
- for parent in parents:
- await _add_provider_to_stack_async(parent, stack, scope, providers)
- if isinstance(provider, ContextResource):
- provider_scope = provider.get_scope()
+ for provider in scope_init_order:
+ if provider in providers:
+ continue
+ providers.add(provider)
+ provider_scope = provider._scope # noqa: SLF001
if provider_scope in (ContextScopes.ANY, scope):
if stack is None:
msg = (
@@ -343,34 +481,28 @@ async def _add_provider_to_stack_async(
def _resolve_provider_with_scope_sync(
provider: AbstractProvider[T],
scope: ContextScope | None,
- stack: ExitStack | None,
+ stack: _SyncInjectionStack | None,
providers: set[AbstractProvider[typing.Any]],
) -> T:
- _add_provider_to_stack_sync(provider, stack, scope, providers)
+ scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001
+ if scope_context_init_order:
+ _setup_scope_contexts_sync(scope_context_init_order, scope, stack, providers)
return provider.resolve_sync()
-def _add_provider_to_stack_sync(
- provider: AbstractProvider[T],
- stack: ExitStack | None,
+def _setup_scope_contexts_sync(
+ scope_init_order: tuple[AbstractProvider[typing.Any], ...],
scope: ContextScope | None,
+ stack: _SyncInjectionStack | None,
providers: set[AbstractProvider[typing.Any]],
) -> None:
- if provider in providers:
- return
- providers.add(provider)
-
if not scope:
return
- if isinstance(provider, ProviderWithArguments):
- provider._register_arguments() # noqa: SLF001
-
- parents = provider._parents # noqa: SLF001
- for parent in parents:
- _add_provider_to_stack_sync(parent, stack, scope, providers)
-
- if isinstance(provider, ContextResource):
- provider_scope = provider.get_scope()
+ for provider in scope_init_order:
+ if provider in providers:
+ continue
+ providers.add(provider)
+ provider_scope = provider._scope # noqa: SLF001
if provider_scope in (ContextScopes.ANY, scope):
if stack is None:
msg = (
@@ -378,7 +510,8 @@ def _add_provider_to_stack_sync(
f"Note: @inject cannot initialize context for ContextResources when wrapping a generator."
)
raise ContextProviderError(msg)
- stack.enter_context(provider.context_sync(force=True))
+ _, exit_state = provider._enter_injection_context_sync(force=True) # noqa: SLF001
+ stack.push_exit_state(exit_state)
class StringProviderDefinition:
diff --git a/that_depends/integrations/faststream.py b/that_depends/integrations/faststream.py
index 67d9fc64..2120f5bf 100644
--- a/that_depends/integrations/faststream.py
+++ b/that_depends/integrations/faststream.py
@@ -22,7 +22,7 @@ class DIContextMiddleware(BaseMiddleware):
def __init__(
self,
- *context_items: SupportsContext[Any],
+ *context_items: SupportsContext[typing.Any],
msg: AnyMsg | None = None,
context: Optional["ContextRepo"] = None,
global_context: dict[str, Any] | Unset = UNSET,
@@ -31,7 +31,7 @@ def __init__(
"""Initialize the container context middleware.
Args:
- *context_items (SupportsContext[Any]): Context items to initialize.
+ *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize.
msg (Any): Message object.
context (ContextRepo): Context repository.
global_context (dict[str, Any] | Unset): Global context to initialize the container.
@@ -40,7 +40,7 @@ def __init__(
"""
super().__init__(msg, context=context) # type: ignore[arg-type]
self._context: container_context | None = None
- self._context_items = set(context_items)
+ self._context_items: set[SupportsContext[typing.Any]] = set(context_items)
self._global_context = global_context
self._scope = scope
@@ -79,8 +79,8 @@ def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": #
return DIContextMiddleware(
*self._context_items,
- msg=msg,
- context=context,
+ msg=msg, # type:ignore[unexpected-keyword]
+ context=context, # type:ignore[unexpected-keyword]
scope=self._scope,
global_context=self._global_context,
)
@@ -93,21 +93,21 @@ class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef]
def __init__(
self,
- *context_items: SupportsContext[Any],
+ *context_items: SupportsContext[typing.Any],
global_context: dict[str, Any] | Unset = UNSET,
scope: ContextScope | Unset = UNSET,
) -> None:
"""Initialize the container context middleware.
Args:
- *context_items (SupportsContext[Any]): Context items to initialize.
+ *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize.
global_context (dict[str, Any] | Unset): Global context to initialize the container.
scope (ContextScope | Unset): Context scope to initialize the container.
"""
super().__init__() # type: ignore[call-arg]
self._context: container_context | None = None
- self._context_items = set(context_items)
+ self._context_items: set[SupportsContext[typing.Any]] = set(context_items)
self._global_context = global_context
self._scope = scope
diff --git a/that_depends/meta.py b/that_depends/meta.py
index 21507a49..614479d0 100644
--- a/that_depends/meta.py
+++ b/that_depends/meta.py
@@ -2,7 +2,6 @@
import typing
import warnings
from collections.abc import MutableMapping
-from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from threading import Lock
from typing import TYPE_CHECKING
@@ -11,7 +10,7 @@
import that_depends
from that_depends.exceptions import TypeNotBoundError
from that_depends.providers import AbstractProvider, Resource
-from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes, SupportsContext
+from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes
from that_depends.providers.mixin import SupportsTeardown
@@ -43,41 +42,15 @@ def __setitem__(self, key: str, value: typing.Any) -> None:
super().__setitem__(key, value)
-class BaseContainerMeta(abc.ABCMeta, SupportsContext[None]):
+class BaseContainerMeta(abc.ABCMeta):
"""Metaclass for BaseContainer."""
- @override
def get_scope(cls) -> ContextScope | None:
+ """Return the default scope used by the container."""
if scope := getattr(cls, "default_scope", None):
return typing.cast(ContextScope | None, scope)
return ContextScopes.ANY
- @asynccontextmanager
- @override
- async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]:
- async with AsyncExitStack() as stack:
- for container in cls.get_containers():
- await stack.enter_async_context(container.context_async(force=force))
- for provider in cls.get_providers().values():
- if isinstance(provider, ContextResource):
- await stack.enter_async_context(provider.context_async(force=force))
- yield
-
- @contextmanager
- @override
- def context_sync(cls, force: bool = False) -> typing.Iterator[None]:
- with ExitStack() as stack:
- for container in cls.get_containers():
- stack.enter_context(container.context_sync(force=force))
- for provider in cls.get_providers().values():
- if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001
- stack.enter_context(provider.context_sync(force=force))
- yield
-
- @override
- def supports_context_sync(cls) -> bool:
- return True
-
_instances: typing.ClassVar[dict[str, type["BaseContainer"]]] = {}
_MUTABLE_ATTRS = (
@@ -88,6 +61,7 @@ def supports_context_sync(cls) -> bool:
"containers",
"alias",
"default_scope",
+ "type_provider_cache",
)
_lock: Lock = Lock()
@@ -114,10 +88,6 @@ def name(cls) -> str:
def __prepare__(cls, name: str, bases: tuple[type, ...], /, **kwds: typing.Any) -> MutableMapping[str, object]:
return _ContainerMetaDict()
- @override
- def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type:
- return super().__new__(cls, name, bases, namespace)
-
@classmethod
def get_instances(cls) -> dict[str, type["BaseContainer"]]:
"""Get all instances that inherit from BaseContainer."""
@@ -163,12 +133,18 @@ def get_provider_for_type(cls, t: type[T]) -> AbstractProvider[T]:
Provider for the given type.
"""
+ if not hasattr(cls, "type_provider_cache"):
+ cls.type_provider_cache: dict[type[typing.Any], AbstractProvider[typing.Any]] = {}
+ if provider := cls.type_provider_cache.get(t):
+ return typing.cast(AbstractProvider[T], provider)
for provider in cls.get_providers().values():
if provider._has_contravariant_bindings: # noqa: SLF001
for bind in provider._bindings: # noqa: SLF001
if issubclass(bind, t):
+ cls.type_provider_cache[t] = provider
return provider
elif t in provider._bindings: # noqa: SLF001
+ cls.type_provider_cache[t] = provider
return provider
msg = f"Type {t} is not bound to any provider in container {cls.name()}"
raise TypeNotBoundError(msg)
diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py
index 9c30bdfe..7adc02d1 100644
--- a/that_depends/providers/base.py
+++ b/that_depends/providers/base.py
@@ -11,6 +11,7 @@
from that_depends.entities.resource_context import ResourceContext
from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown
+from that_depends.utils import UNSET, is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -18,8 +19,71 @@
P = typing.ParamSpec("P")
ResourceCreatorType: typing.TypeAlias = typing.Callable[
P,
- typing.Iterator[T_co] | typing.AsyncIterator[T_co] | typing.ContextManager[T_co] | typing.AsyncContextManager[T_co],
+ typing.Iterator[T_co]
+ | typing.AsyncIterator[T_co]
+ | contextlib.AbstractContextManager[T_co]
+ | contextlib.AbstractAsyncContextManager[T_co],
]
+_EMPTY_ARGS: typing.Final[tuple[()]] = ()
+_EMPTY_KWARGS: typing.Final[dict[str, typing.Any]] = {}
+
+
+async def _resolve_arguments(
+ args: tuple[typing.Any, ...],
+ args_are_providers: tuple[bool, ...],
+) -> tuple[typing.Any, ...] | list[typing.Any]:
+ if not args:
+ return _EMPTY_ARGS
+ if len(args) == 1:
+ arg = args[0]
+ return (await arg.resolve() if args_are_providers[0] else arg,)
+ return [
+ await arg.resolve() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False)
+ ]
+
+
+def _resolve_arguments_sync(
+ args: tuple[typing.Any, ...],
+ args_are_providers: tuple[bool, ...],
+) -> tuple[typing.Any, ...] | list[typing.Any]:
+ if not args:
+ return _EMPTY_ARGS
+ if len(args) == 1:
+ arg = args[0]
+ return (arg.resolve_sync() if args_are_providers[0] else arg,)
+ return [
+ arg.resolve_sync() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False)
+ ]
+
+
+async def _resolve_keyword_arguments(
+ kwargs_items: tuple[tuple[str, typing.Any], ...],
+ kwargs_are_providers: tuple[bool, ...],
+) -> dict[str, typing.Any]:
+ if not kwargs_items:
+ return _EMPTY_KWARGS
+ if len(kwargs_items) == 1:
+ key, value = kwargs_items[0]
+ return {key: await value.resolve() if kwargs_are_providers[0] else value}
+ return {
+ key: await value.resolve() if is_provider else value
+ for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False)
+ }
+
+
+def _resolve_keyword_arguments_sync(
+ kwargs_items: tuple[tuple[str, typing.Any], ...],
+ kwargs_are_providers: tuple[bool, ...],
+) -> dict[str, typing.Any]:
+ if not kwargs_items:
+ return _EMPTY_KWARGS
+ if len(kwargs_items) == 1:
+ key, value = kwargs_items[0]
+ return {key: value.resolve_sync() if kwargs_are_providers[0] else value}
+ return {
+ key: value.resolve_sync() if is_provider else value
+ for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False)
+ }
class AbstractProvider(abc.ABC, typing.Generic[T_co]):
@@ -30,7 +94,10 @@ def __init__(self) -> None:
super().__init__()
self._children: set[AbstractProvider[typing.Any]] = set()
self._parents: set[AbstractProvider[typing.Any]] = set()
- self._override: typing.Any = None
+ self._is_context_resource = False
+ self._scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None
+ self._scope_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None
+ self._override: typing.Any = UNSET
self._bindings: set[type] = set()
self._has_contravariant_bindings: bool = False
self._lock = threading.Lock()
@@ -62,10 +129,14 @@ def _register(self, candidates: typing.Iterable[typing.Any]) -> None:
None
"""
+ changed = False
for candidate in candidates:
if isinstance(candidate, AbstractProvider):
candidate.add_child_provider(self)
self._parents.add(candidate)
+ changed = True
+ if changed:
+ self._invalidate_scope_init_order()
def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None:
"""Deregister current provider as child.
@@ -77,10 +148,71 @@ def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None:
None
"""
+ changed = False
for candidate in candidates:
if isinstance(candidate, AbstractProvider) and self in candidate._children: # noqa: SLF001
candidate.remove_child_provider(self)
self._parents.discard(candidate)
+ changed = True
+ if changed:
+ self._invalidate_scope_init_order()
+
+ def _invalidate_scope_init_order(self) -> None:
+ stack = [self]
+ visited: set[AbstractProvider[typing.Any]] = set()
+
+ while stack:
+ provider = stack.pop()
+ if provider in visited:
+ continue
+ visited.add(provider)
+ provider._scope_context_init_order = None # noqa: SLF001
+ provider._scope_init_order = None # noqa: SLF001
+ stack.extend(provider._children) # noqa: SLF001
+
+ def _get_scope_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]:
+ if self._scope_init_order is not None:
+ return self._scope_init_order
+
+ if isinstance(self, ProviderWithArguments):
+ self._register_arguments()
+
+ ordered: list[AbstractProvider[typing.Any]] = []
+ seen: set[AbstractProvider[typing.Any]] = set()
+
+ for parent in self._parents:
+ for ancestor in parent._get_scope_init_order(): # noqa: SLF001
+ if ancestor not in seen:
+ seen.add(ancestor)
+ ordered.append(ancestor)
+
+ if self not in seen:
+ ordered.append(self)
+
+ self._scope_init_order = tuple(ordered)
+ return self._scope_init_order
+
+ def _get_scope_context_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]:
+ if self._scope_context_init_order is not None:
+ return self._scope_context_init_order
+
+ if isinstance(self, ProviderWithArguments):
+ self._register_arguments()
+
+ ordered: list[AbstractProvider[typing.Any]] = []
+ seen: set[AbstractProvider[typing.Any]] = set()
+
+ for parent in self._parents:
+ for ancestor in parent._get_scope_context_init_order(): # noqa: SLF001
+ if ancestor not in seen:
+ seen.add(ancestor)
+ ordered.append(ancestor)
+
+ if self._is_context_resource and self not in seen:
+ ordered.append(self)
+
+ self._scope_context_init_order = tuple(ordered)
+ return self._scope_context_init_order
def add_child_provider(self, provider: "AbstractProvider[typing.Any]") -> None:
"""Add a child provider to the current provider.
@@ -232,7 +364,7 @@ def reset_override_sync(
None
"""
- self._override = None
+ self._override = UNSET
if tear_down_children:
eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)]
for child in eligible_children:
@@ -249,7 +381,7 @@ async def reset_override(self, tear_down_children: bool = False, propagate: bool
None
"""
- self._override = None
+ self._override = UNSET
if tear_down_children:
eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)]
for child in eligible_children:
@@ -306,7 +438,10 @@ def __init__(
"""
super().__init__()
- self._creator: typing.Any
+ self._creator: typing.Callable[
+ ...,
+ contextlib.AbstractContextManager[T_co] | contextlib.AbstractAsyncContextManager[T_co],
+ ]
if inspect.isasyncgenfunction(creator):
self._is_async = True
@@ -314,10 +449,10 @@ def __init__(
elif inspect.isgeneratorfunction(creator):
self._is_async = False
self._creator = contextlib.contextmanager(creator)
- elif isinstance(creator, type) and issubclass(creator, typing.AsyncContextManager):
+ elif isinstance(creator, type) and hasattr(creator, "__aenter__") and hasattr(creator, "__aexit__"):
self._is_async = True
self._creator = creator
- elif isinstance(creator, type) and issubclass(creator, typing.ContextManager):
+ elif isinstance(creator, type) and hasattr(creator, "__enter__") and hasattr(creator, "__exit__"):
self._is_async = False
self._creator = creator
else:
@@ -325,60 +460,65 @@ def __init__(
raise TypeError(msg)
self._args = args
self._kwargs = kwargs
+ self._args_are_providers = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items = tuple(kwargs.items())
+ self._kwargs_are_providers = tuple(isinstance(value, AbstractProvider) for _, value in self._kwargs_items)
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
self._deregister(self._args)
self._deregister(self._kwargs.values())
+ self._reset_arguments_registration()
@abc.abstractmethod
def _fetch_context(self) -> ResourceContext[T_co]: ...
@override
async def resolve(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
context = self._fetch_context()
# lock to prevent race condition while resolving
async with context.asyncio_lock:
- if context.instance is not None:
+ if is_set(context.instance):
return context.instance
self._register_arguments()
-
- cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator(
- *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
- **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
+ cm: contextlib.AbstractContextManager[T_co] | contextlib.AbstractAsyncContextManager[T_co] = self._creator(
+ *await _resolve_arguments(self._args, self._args_are_providers),
+ **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers),
)
- if isinstance(cm, typing.AsyncContextManager):
+ if isinstance(cm, contextlib.AbstractAsyncContextManager):
context.context_stack = contextlib.AsyncExitStack()
context.instance = await context.context_stack.enter_async_context(cm)
- elif isinstance(cm, typing.ContextManager):
+ elif isinstance(cm, contextlib.AbstractContextManager):
context.context_stack = contextlib.ExitStack()
context.instance = context.context_stack.enter_context(cm)
else: # pragma: no cover
- typing.assert_never(cm)
+ typing_extensions.assert_never(cm)
return context.instance
@override
def resolve_sync(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
context = self._fetch_context()
# lock to prevent race condition while resolving
with context.threading_lock:
- if context.instance is not None:
+ if is_set(context.instance):
return context.instance
if self._is_async:
@@ -386,15 +526,15 @@ def resolve_sync(self) -> T_co:
raise RuntimeError(msg)
self._register_arguments()
-
cm = self._creator(
- *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args],
- **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
+ *_resolve_arguments_sync(self._args, self._args_are_providers),
+ **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers),
)
context.context_stack = contextlib.ExitStack()
- context.instance = context.context_stack.enter_context(cm)
+ instance: T_co = context.context_stack.enter_context(cm) # type: ignore[arg-type]
+ context.instance = instance
- return context.instance
+ return instance
def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401
@@ -409,6 +549,8 @@ class AttrGetter(
"""Provides an attribute after resolving the wrapped provider."""
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
if isinstance(self._provider, ProviderWithArguments):
self._provider._register_arguments() # noqa: SLF001
self._parents = self._provider._parents # noqa: SLF001
@@ -417,6 +559,7 @@ def _deregister_arguments(self) -> None:
if isinstance(self._provider, ProviderWithArguments):
self._provider._deregister_arguments() # noqa: SLF001
self._parents = self._provider._parents # noqa: SLF001
+ self._reset_arguments_registration()
__slots__ = "_attrs", "_provider"
@@ -433,11 +576,11 @@ def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
self._attrs = [attr_name]
@override
- def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
- if attr.startswith("_"):
- msg = f"'{type(self)}' object has no attribute '{attr}'"
+ def __getattr__(self, attr_name: str) -> "AttrGetter[T_co]":
+ if attr_name.startswith("_"):
+ msg = f"'{type(self)}' object has no attribute '{attr_name}'"
raise AttributeError(msg)
- self._attrs.append(attr)
+ self._attrs.append(attr_name)
return self
@override
diff --git a/that_depends/providers/collection.py b/that_depends/providers/collection.py
index 28533d11..08e0a4ca 100644
--- a/that_depends/providers/collection.py
+++ b/that_depends/providers/collection.py
@@ -1,4 +1,6 @@
import typing
+from collections.abc import Mapping, Sequence
+from types import MappingProxyType
from typing_extensions import override
@@ -8,10 +10,11 @@
T_co = typing.TypeVar("T_co", covariant=True)
-class List(AbstractProvider[list[T_co]]):
- """Provides multiple resources as a list.
+class List(AbstractProvider[Sequence[T_co]]):
+ """Provides multiple resources as a read-only sequence.
- The `List` provider resolves multiple dependencies into a list.
+ The `List` provider resolves multiple dependencies into an immutable tuple
+ exposed through the `Sequence` interface.
Example:
```python
@@ -24,12 +27,12 @@ class List(AbstractProvider[list[T_co]]):
# Synchronous resolution
resolved_list = list_provider.resolve_sync()
- print(resolved_list) # Output: [1, 2]
+ print(tuple(resolved_list)) # Output: (1, 2)
# Asynchronous resolution
import asyncio
resolved_list_async = asyncio.run(list_provider.resolve())
- print(resolved_list_async) # Output: [1, 2]
+ print(tuple(resolved_list_async)) # Output: (1, 2)
```
"""
@@ -53,22 +56,24 @@ def __getattr__(self, attr_name: str) -> typing.Any:
raise AttributeError(msg)
@override
- async def resolve(self) -> list[T_co]:
- return [await x.resolve() for x in self._providers]
+ async def resolve(self) -> Sequence[T_co]:
+ resolved = [await x.resolve() for x in self._providers]
+ return tuple(resolved)
@override
- def resolve_sync(self) -> list[T_co]:
- return [x.resolve_sync() for x in self._providers]
+ def resolve_sync(self) -> Sequence[T_co]:
+ resolved = [x.resolve_sync() for x in self._providers]
+ return tuple(resolved)
@override
- async def __call__(self) -> list[T_co]:
+ async def __call__(self) -> Sequence[T_co]:
return await self.resolve()
-class Dict(AbstractProvider[dict[str, T_co]]):
- """Provides multiple resources as a dictionary.
+class Dict(AbstractProvider[Mapping[str, T_co]]):
+ """Provides multiple resources as a read-only mapping.
- The `Dict` provider resolves multiple named dependencies into a dictionary.
+ The `Dict` provider resolves multiple named dependencies into a read-only mapping.
Example:
```python
@@ -81,12 +86,12 @@ class Dict(AbstractProvider[dict[str, T_co]]):
# Synchronous resolution
resolved_dict = dict_provider.resolve_sync()
- print(resolved_dict) # Output: {"key1": 1, "key2": 2}
+ print(dict(resolved_dict)) # Output: {"key1": 1, "key2": 2}
# Asynchronous resolution
import asyncio
resolved_dict_async = asyncio.run(dict_provider.resolve())
- print(resolved_dict_async) # Output: {"key1": 1, "key2": 2}
+ print(dict(resolved_dict_async)) # Output: {"key1": 1, "key2": 2}
```
"""
@@ -110,9 +115,9 @@ def __getattr__(self, attr_name: str) -> typing.Any:
raise AttributeError(msg)
@override
- async def resolve(self) -> dict[str, T_co]:
- return {key: await provider.resolve() for key, provider in self._providers.items()}
+ async def resolve(self) -> Mapping[str, T_co]:
+ return MappingProxyType({key: await provider.resolve() for key, provider in self._providers.items()})
@override
- def resolve_sync(self) -> dict[str, T_co]:
- return {key: provider.resolve_sync() for key, provider in self._providers.items()}
+ def resolve_sync(self) -> Mapping[str, T_co]:
+ return MappingProxyType({key: provider.resolve_sync() for key, provider in self._providers.items()})
diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py
index 4dab61ab..549b3cd2 100644
--- a/that_depends/providers/context_resources.py
+++ b/that_depends/providers/context_resources.py
@@ -3,7 +3,6 @@
import inspect
import logging
import typing
-from abc import abstractmethod
from collections.abc import Iterable
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from contextvars import ContextVar, Token
@@ -11,10 +10,11 @@
from types import TracebackType
from typing import Final, overload
-from typing_extensions import TypeIs, override
+from typing_extensions import Protocol, TypeIs, override, runtime_checkable
from that_depends.entities.resource_context import ResourceContext
from that_depends.providers.base import AbstractResource
+from that_depends.utils import UNSET
if typing.TYPE_CHECKING:
@@ -41,6 +41,55 @@ class InvalidContextError(RuntimeError):
"""Raised when an invalid context is being used."""
+class _SyncInjectionExitState(typing.Generic[T_co]):
+ __slots__ = ("_context", "_context_item", "_token")
+
+ def __init__(
+ self,
+ context: ContextVar[ResourceContext[T_co]],
+ context_item: ResourceContext[T_co],
+ token: Token[ResourceContext[T_co]],
+ ) -> None:
+ self._context = context
+ self._context_item = context_item
+ self._token = token
+
+ def close(self) -> None:
+ context_stack = self._context_item.context_stack
+ if self._context_item.is_context_stack_sync(context_stack):
+ context_stack.close() # type: ignore[union-attr]
+ self._context_item.context_stack = UNSET
+ self._context_item.instance = UNSET
+ self._context.reset(self._token)
+
+
+class _SyncContextResourceContext(contextlib.ContextDecorator, AbstractContextManager[ResourceContext[T_co]]):
+ __slots__ = ("_exit_state", "_force", "_provider")
+
+ def __init__(self, provider: "ContextResource[T_co]", force: bool) -> None:
+ self._provider = provider
+ self._force = force
+ self._exit_state: _SyncInjectionExitState[T_co] | None = None
+
+ @override
+ def __enter__(self) -> ResourceContext[T_co]:
+ value, self._exit_state = self._provider._enter_injection_context_sync(force=self._force) # noqa: SLF001
+ return value
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ if self._exit_state is None:
+ msg = "Context is not set, call ``__enter__`` first"
+ raise RuntimeError(msg)
+ _ = exc_type, exc_value, traceback
+ self._exit_state.close()
+
+
class ContextScope:
"""A named context scope."""
@@ -102,22 +151,21 @@ def _enter_named_scope(scope: ContextScope) -> typing.Iterator[ContextScope]:
T = typing.TypeVar("T")
-CT = typing.TypeVar("CT")
+CT_co = typing.TypeVar("CT_co", covariant=True)
-class SupportsContext(typing.Generic[CT]):
+@runtime_checkable
+class SupportsContext(Protocol[CT_co]):
"""Interface for resources that support context initialization.
This interface defines methods to create synchronous and asynchronous
context managers, as well as a function decorator for context initialization.
"""
- @abstractmethod
def get_scope(self) -> ContextScope | None:
"""Return the scope of the resource."""
- @abstractmethod
- def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]:
+ def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]:
"""Create an async context manager for this resource.
Args:
@@ -133,9 +181,9 @@ def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]:
```
"""
+ ...
- @abstractmethod
- def context_sync(self, force: bool = False) -> typing.ContextManager[CT]:
+ def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]:
"""Create a sync context manager for this resource.
Args:
@@ -151,8 +199,8 @@ def context_sync(self, force: bool = False) -> typing.ContextManager[CT]:
```
"""
+ ...
- @abstractmethod
def supports_context_sync(self) -> bool:
"""Check whether the resource supports sync context.
@@ -160,6 +208,10 @@ def supports_context_sync(self) -> bool:
bool: True if sync context is supported, False otherwise.
"""
+ ...
+
+
+BaseContainerType: typing.TypeAlias = type["BaseContainer"]
def _get_container_context() -> dict[str, typing.Any] | None:
@@ -228,16 +280,20 @@ class ContextResource(
@override
async def resolve(self) -> T_co:
+ if not self._strict_scope or self._scope == ContextScopes.ANY:
+ return await super().resolve()
current_scope = get_current_scope()
- if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope):
+ if self._scope == current_scope:
return await super().resolve()
msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`"
raise RuntimeError(msg)
@override
def resolve_sync(self) -> T_co:
+ if not self._strict_scope or self._scope == ContextScopes.ANY:
+ return super().resolve_sync()
current_scope = get_current_scope()
- if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope):
+ if self._scope == current_scope:
return super().resolve_sync()
msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`"
raise RuntimeError(msg)
@@ -274,7 +330,8 @@ def __init__(
"""
super().__init__(creator, *args, **kwargs)
- self._from_creator = creator
+ self._from_creator: typing.Callable[..., typing.Iterator[T_co] | typing.AsyncIterator[T_co]] = creator
+ self._is_context_resource = True
self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context")
self._token: Token[ResourceContext[T_co]] | None = None
self._async_lock: Final = asyncio.Lock()
@@ -334,7 +391,7 @@ def with_config(self, scope: ContextScope | None, strict_scope: bool = False) ->
if strict_scope and scope == ContextScopes.ANY:
msg = f"Cannot set strict_scope with scope {scope}."
raise ValueError(msg)
- r = ContextResource(self._from_creator, *self._args, **self._kwargs) # type: ignore[arg-type]
+ r = ContextResource(self._from_creator, *self._args, **self._kwargs)
r._scope = scope
r._strict_scope = strict_scope
@@ -350,24 +407,48 @@ def _enter_context_sync(self, force: bool = False) -> ResourceContext[T_co]:
raise RuntimeError(msg)
return self._enter(force)
+ def _enter_injection_context_sync(
+ self,
+ force: bool = False,
+ ) -> tuple[ResourceContext[T_co], _SyncInjectionExitState[T_co]]:
+ if self._is_async:
+ msg = "Please use async context instead."
+ raise RuntimeError(msg)
+ if not force and self._scope != ContextScopes.ANY:
+ current_scope = get_current_scope()
+ if self._scope != current_scope:
+ msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}"
+ raise InvalidContextError(msg)
+
+ context_item: ResourceContext[T_co] = ResourceContext(is_async=False)
+ token = self._context.set(context_item)
+ return context_item, _SyncInjectionExitState(self._context, context_item, token)
+
async def _enter_context_async(self, force: bool = False) -> ResourceContext[T_co]:
return self._enter(force)
def _enter(self, force: bool = False) -> ResourceContext[T_co]:
- if not force and self._scope not in (ContextScopes.ANY, get_current_scope()):
- msg = f"Cannot enter context for resource with scope {self._scope} in scope {get_current_scope()!r}"
- raise InvalidContextError(msg)
- self._token = self._context.set(ResourceContext(is_async=self._is_async))
- return self._context.get()
+ if not force and self._scope != ContextScopes.ANY:
+ current_scope = get_current_scope()
+ if self._scope != current_scope:
+ msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}"
+ raise InvalidContextError(msg)
+ context_item: ResourceContext[T_co] = ResourceContext(is_async=self._is_async)
+ self._token = self._context.set(context_item)
+ return context_item
def _exit_context_sync(self) -> None:
- if not self._token:
+ if self._token is None:
msg = "Context is not set, call ``_enter_sync_context`` first"
raise RuntimeError(msg)
try:
context_item = self._context.get()
- context_item.tear_down_sync()
+ context_stack = context_item.context_stack
+ if context_item.is_context_stack_sync(context_stack):
+ context_stack.close() # type: ignore[union-attr]
+ context_item.context_stack = UNSET
+ context_item.instance = UNSET
finally:
self._context.reset(self._token)
@@ -385,21 +466,9 @@ async def _exit_context_async(self) -> None:
finally:
self._context.reset(self._token)
- @contextlib.contextmanager
@override
- def context_sync(self, force: bool = False) -> typing.Iterator[ResourceContext[T_co]]:
- if self._is_async:
- msg = "Please use async context instead."
- raise RuntimeError(msg)
- token = self._token
- with self._lock:
- val = self._enter_context_sync(force=force)
- temp_token = self._token
- yield val
- with self._lock:
- self._token = temp_token
- self._exit_context_sync()
- self._token = token
+ def context_sync(self, force: bool = False) -> _SyncContextResourceContext[T_co]:
+ return _SyncContextResourceContext(self, force)
@contextlib.asynccontextmanager
@override
@@ -423,9 +492,6 @@ def _fetch_context(self) -> ResourceContext[T_co]:
raise RuntimeError(msg) from e
-ContainerType = typing.TypeVar("ContainerType", bound="type[BaseContainer]")
-
-
class container_context(AbstractContextManager[ContextType], AbstractAsyncContextManager[ContextType]): # noqa: N801
"""Initialize contexts for the provided containers or resources.
@@ -434,15 +500,16 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex
"""
__slots__ = (
- "_containers",
+ "_container_items",
+ "_container_providers_by_scope",
"_context_items",
- "_context_providers",
"_context_stack",
"_context_token",
+ "_direct_context_items",
+ "_entered_context_items",
"_global_context",
"_initial_context",
"_preserve_global_context",
- "_providers",
"_reset_resource_context",
"_scope",
"_scope_token",
@@ -458,7 +525,7 @@ def __init__(
"""Initialize a new container context.
Args:
- *context_items (SupportsContext[Any]): Context items to initialize a new context for.
+ *context_items: Context-capable providers or container classes to initialize.
global_context (dict[str, Any] | None): A dictionary representing the global context.
preserve_global_context (bool): If True, merges the existing global context with the new one.
scope (ContextScope | None): The named scope that should be initialized.
@@ -481,13 +548,44 @@ def __init__(
self._global_context = global_context
self._context_token: Token[ContextType] | None = None
self._context_items: typing.Final[set[SupportsContext[typing.Any]]] = set(context_items)
- self._context_providers: set[ContextResource[typing.Any]] = set()
+ self._container_items: tuple[type[BaseContainer], ...]
+ self._direct_context_items: tuple[SupportsContext[typing.Any], ...]
+ self._container_providers_by_scope: dict[ContextScope | None, tuple[ContextResource[typing.Any], ...]] = {}
+ self._entered_context_items: tuple[SupportsContext[typing.Any], ...] = ()
self._reset_resource_context: typing.Final[bool] = bool(scope)
self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None
self._scope_token: Token[ContextScope | None] | None = None
+ self._container_items, self._direct_context_items = self._parse_context_items(self._context_items)
+
+ def _parse_context_items(
+ self,
+ context_items: set[SupportsContext[typing.Any]],
+ ) -> tuple[tuple[type["BaseContainer"], ...], tuple[SupportsContext[typing.Any], ...]]:
+ from that_depends.container import BaseContainer # noqa: PLC0415
+
+ containers: list[type[BaseContainer]] = []
+ direct_items: list[SupportsContext[typing.Any]] = []
+ for item in context_items:
+ if isinstance(item, type) and issubclass(item, BaseContainer):
+ containers.append(item)
+ else:
+ direct_items.append(item)
+ return tuple(containers), tuple(direct_items)
- def _resolve_initial_conditions(self) -> None:
- self._scope = self._scope or get_current_scope()
+ def _get_context_providers_for_scope(
+ self,
+ scope: ContextScope | None,
+ ) -> tuple[ContextResource[typing.Any], ...]:
+ cached = self._container_providers_by_scope.get(scope)
+ if cached is None:
+ providers: set[ContextResource[typing.Any]] = set()
+ self._add_providers_from_containers(self._container_items, providers, scope)
+ cached = tuple(providers)
+ self._container_providers_by_scope[scope] = cached
+ return cached
+
+ def _resolve_initial_conditions(self) -> ContextScope | None:
+ scope = self._scope or get_current_scope()
if self._preserve_global_context and self._global_context:
if context := _get_container_context():
self._initial_context = {**context, **self._global_context}
@@ -499,44 +597,48 @@ def _resolve_initial_conditions(self) -> None:
)
else:
self._initial_context = self._global_context or {}
+ entered_context_items = dict.fromkeys(self._direct_context_items)
+ for provider in self._get_context_providers_for_scope(scope):
+ entered_context_items[provider] = provider
if self._reset_resource_context:
from that_depends.meta import BaseContainerMeta # noqa: PLC0415
- self._add_providers_from_containers(BaseContainerMeta.get_instances().values(), self._scope)
- for item in self._context_items:
- from that_depends.container import BaseContainer # noqa: PLC0415
-
- if isinstance(item, type) and issubclass(item, BaseContainer):
- self._add_providers_from_containers([item], self._scope)
- elif isinstance(item, ContextResource):
- self._context_providers.add(item)
+ scope_providers: set[ContextResource[typing.Any]] = set()
+ self._add_providers_from_containers(BaseContainerMeta.get_instances().values(), scope_providers, scope)
+ for provider in scope_providers:
+ entered_context_items[provider] = provider
+ self._entered_context_items = tuple(entered_context_items)
+ return scope
def _add_providers_from_containers(
- self, containers: Iterable[ContainerType], scope: ContextScope | None = ContextScopes.ANY
+ self,
+ containers: Iterable[BaseContainerType],
+ target: set[ContextResource[typing.Any]],
+ scope: ContextScope | None = ContextScopes.ANY,
) -> None:
for container in containers:
for container_provider in container.get_providers().values():
if isinstance(container_provider, ContextResource):
provider_scope = container_provider.get_scope()
if provider_scope in (scope, ContextScopes.ANY):
- self._context_providers.add(container_provider)
+ target.add(container_provider)
@override
def __enter__(self) -> ContextType:
- self._resolve_initial_conditions()
+ scope = self._resolve_initial_conditions()
self._context_stack = contextlib.ExitStack()
- self._scope_token = _set_current_scope(self._scope)
- for item in self._context_providers:
+ self._scope_token = _set_current_scope(scope)
+ for item in self._entered_context_items:
if item.supports_context_sync():
self._context_stack.enter_context(item.context_sync())
return self._enter_globals()
@override
async def __aenter__(self) -> ContextType:
- self._resolve_initial_conditions()
+ scope = self._resolve_initial_conditions()
self._context_stack = contextlib.AsyncExitStack()
- self._scope_token = _set_current_scope(self._scope)
- for item in self._context_providers:
+ self._scope_token = _set_current_scope(scope)
+ for item in self._entered_context_items:
await self._context_stack.enter_async_context(item.context_async())
return self._enter_globals()
@@ -667,8 +769,7 @@ def __init__(
Args:
app (ASGIApp): The ASGI application to wrap.
- *context_items (SupportsContext[Any]): A collection of containers and providers that
- need context initialization prior to a request.
+ *context_items: Containers and providers that need context initialization prior to a request.
global_context (dict[str, Any] | None): A global context dictionary to set before requests.
scope (ContextScope | None): The scope in which the context should be initialized.
diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py
index 6242ea5f..9102d0da 100644
--- a/that_depends/providers/factories.py
+++ b/that_depends/providers/factories.py
@@ -5,8 +5,15 @@
from typing_extensions import override
-from that_depends.providers.base import AbstractProvider
+from that_depends.providers.base import (
+ AbstractProvider,
+ _resolve_arguments,
+ _resolve_arguments_sync,
+ _resolve_keyword_arguments,
+ _resolve_keyword_arguments_sync,
+)
from that_depends.providers.mixin import ProviderWithArguments
+from that_depends.utils import is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -76,13 +83,23 @@ def build_resource(text: str, number: int):
"""
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
raise NotImplementedError
- __slots__ = "_args", "_factory", "_kwargs", "_override"
+ __slots__ = (
+ "_args",
+ "_args_are_providers",
+ "_factory",
+ "_kwargs",
+ "_kwargs_are_providers",
+ "_kwargs_items",
+ "_override",
+ )
def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
"""Initialize a Factory instance.
@@ -94,37 +111,34 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P
"""
super().__init__()
- self._factory: typing.Final = factory
+ self._factory: typing.Final[typing.Callable[..., T_co]] = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
+ self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items: typing.Final = tuple(kwargs.items())
+ self._kwargs_are_providers: typing.Final = tuple(
+ isinstance(value, AbstractProvider) for _, value in self._kwargs_items
+ )
self._register_arguments()
@override
async def resolve(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
return self._factory(
- *[ # type: ignore[arg-type]
- await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args
- ],
- **{ # type: ignore[arg-type]
- k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
- },
+ *await _resolve_arguments(self._args, self._args_are_providers),
+ **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers),
)
@override
def resolve_sync(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
return self._factory(
- *[ # type: ignore[arg-type]
- x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args
- ],
- **{ # type: ignore[arg-type]
- k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
- },
+ *_resolve_arguments_sync(self._args, self._args_are_providers),
+ **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers),
)
@@ -147,13 +161,23 @@ async def async_build_resource(text: str):
"""
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
raise NotImplementedError
- __slots__ = "_args", "_factory", "_kwargs", "_override"
+ __slots__ = (
+ "_args",
+ "_args_are_providers",
+ "_factory",
+ "_kwargs",
+ "_kwargs_are_providers",
+ "_kwargs_items",
+ "_override",
+ )
@overload
def __init__(
@@ -176,22 +200,27 @@ def __init__(
"""
super().__init__()
- self._factory: typing.Final = factory
+ self._factory: typing.Final[typing.Callable[..., T_co | typing.Awaitable[T_co]]] = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
+ self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items: typing.Final = tuple(kwargs.items())
+ self._kwargs_are_providers: typing.Final = tuple(
+ isinstance(value, AbstractProvider) for _, value in self._kwargs_items
+ )
self._register_arguments()
@override
async def resolve(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
- args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args]
- kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}
+ args = await _resolve_arguments(self._args, self._args_are_providers)
+ kwargs = await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers)
result = self._factory(
- *args, # type:ignore[arg-type]
- **kwargs, # type:ignore[arg-type]
+ *args,
+ **kwargs,
)
if inspect.isawaitable(result):
diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py
index dd8f1ef0..20228f11 100644
--- a/that_depends/providers/local_singleton.py
+++ b/that_depends/providers/local_singleton.py
@@ -5,7 +5,14 @@
from typing_extensions import override
from that_depends.providers import AbstractProvider
+from that_depends.providers.base import (
+ _resolve_arguments,
+ _resolve_arguments_sync,
+ _resolve_keyword_arguments,
+ _resolve_keyword_arguments_sync,
+)
from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown
+from that_depends.utils import UNSET, Unset, is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -50,67 +57,77 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P
"""
super().__init__()
- self._factory: typing.Final = factory
+ self._factory: typing.Final[typing.Callable[..., T_co]] = factory
self._thread_local = threading.local()
self._asyncio_lock = asyncio.Lock()
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
+ self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items: typing.Final = tuple(kwargs.items())
+ self._kwargs_are_providers: typing.Final = tuple(
+ isinstance(value, AbstractProvider) for _, value in self._kwargs_items
+ )
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
self._deregister(self._args)
self._deregister(self._kwargs.values())
+ self._reset_arguments_registration()
@property
- def _instance(self) -> T_co | None:
- return getattr(self._thread_local, "instance", None)
+ def _instance(self) -> T_co | Unset:
+ return typing.cast(T_co | Unset, getattr(self._thread_local, "instance", UNSET))
@_instance.setter
- def _instance(self, value: T_co | None) -> None:
+ def _instance(self, value: T_co | Unset) -> None:
self._thread_local.instance = value
@override
async def resolve(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
+ if is_set(self._instance):
+ return self._instance
async with self._asyncio_lock:
- if self._instance is not None:
+ if is_set(self._instance):
return self._instance
self._register_arguments()
- self._instance = self._factory(
- *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
- **{ # type: ignore[arg-type]
- k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
- },
+ instance = self._factory(
+ *await _resolve_arguments(self._args, self._args_are_providers),
+ **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers),
)
- return self._instance
+ self._instance = instance
+ return instance
@override
def resolve_sync(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
- if self._instance is not None:
+ if is_set(self._instance):
return self._instance
self._register_arguments()
- self._instance = self._factory(
- *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
- **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type]
+ instance = self._factory(
+ *_resolve_arguments_sync(self._args, self._args_are_providers),
+ **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers),
)
- return self._instance
+ self._instance = instance
+ return instance
@override
def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)
@@ -122,8 +139,8 @@ async def tear_down(self, propagate: bool = True) -> None:
After calling this method, subsequent calls to `resolve_sync()` on the
same thread will produce a new instance.
"""
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
await self._tear_down_children()
diff --git a/that_depends/providers/mixin.py b/that_depends/providers/mixin.py
index 5462230b..99597af0 100644
--- a/that_depends/providers/mixin.py
+++ b/that_depends/providers/mixin.py
@@ -26,6 +26,22 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) ->
class ProviderWithArguments(abc.ABC):
"""Interface for providers that require arguments."""
+ __slots__ = ("_arguments_registered",)
+
+ def __init__(self) -> None:
+ """Initialize provider argument registration state."""
+ super().__init__()
+ self._arguments_registered = False
+
+ def _mark_arguments_registered(self) -> bool:
+ if self._arguments_registered:
+ return False
+ self._arguments_registered = True
+ return True
+
+ def _reset_arguments_registration(self) -> None:
+ self._arguments_registered = False
+
@abc.abstractmethod
def _register_arguments(self) -> None:
"""Register arguments for the provider."""
diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py
index b29ed7cd..c0843eae 100644
--- a/that_depends/providers/object.py
+++ b/that_depends/providers/object.py
@@ -3,6 +3,7 @@
from typing_extensions import override
from that_depends.providers.base import AbstractProvider
+from that_depends.utils import is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -42,6 +43,6 @@ async def resolve(self) -> T_co:
@override
def resolve_sync(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
return self._obj
diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py
index 5bbe023f..9ec2c48b 100644
--- a/that_depends/providers/selector.py
+++ b/that_depends/providers/selector.py
@@ -5,6 +5,7 @@
from typing_extensions import override
from that_depends.providers.base import AbstractProvider
+from that_depends.utils import is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -69,7 +70,7 @@ def my_selector():
@override
async def resolve(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
if isinstance(self._selector, AbstractProvider):
@@ -83,7 +84,7 @@ async def resolve(self) -> T_co:
@override
def resolve_sync(self) -> T_co:
- if self._override:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
if isinstance(self._selector, AbstractProvider):
diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py
index daa470d6..961f77ab 100644
--- a/that_depends/providers/singleton.py
+++ b/that_depends/providers/singleton.py
@@ -6,8 +6,15 @@
from typing_extensions import override
-from that_depends.providers.base import AbstractProvider
+from that_depends.providers.base import (
+ AbstractProvider,
+ _resolve_arguments,
+ _resolve_arguments_sync,
+ _resolve_keyword_arguments,
+ _resolve_keyword_arguments_sync,
+)
from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown
+from that_depends.utils import UNSET, Unset, is_set
T_co = typing.TypeVar("T_co", covariant=True)
@@ -34,7 +41,18 @@ def my_factory() -> float:
"""
- __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override", "_threading_lock"
+ __slots__ = (
+ "_args",
+ "_args_are_providers",
+ "_asyncio_lock",
+ "_factory",
+ "_instance",
+ "_kwargs",
+ "_kwargs_are_providers",
+ "_kwargs_items",
+ "_override",
+ "_threading_lock",
+ )
def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
"""Initialize the Singleton provider.
@@ -46,54 +64,64 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P
"""
super().__init__()
- self._factory: typing.Final = factory
- self._instance: T_co | None = None
+ self._factory: typing.Final[typing.Callable[..., T_co]] = factory
+ self._instance: T_co | Unset = UNSET
self._asyncio_lock: typing.Final = asyncio.Lock()
self._threading_lock: typing.Final = threading.Lock()
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
+ self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items: typing.Final = tuple(kwargs.items())
+ self._kwargs_are_providers: typing.Final = tuple(
+ isinstance(value, AbstractProvider) for _, value in self._kwargs_items
+ )
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
self._deregister(self._args)
self._deregister(self._kwargs.values())
+ self._reset_arguments_registration()
@override
async def resolve(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
self._register_arguments()
return typing.cast(T_co, self._override)
+ if is_set(self._instance):
+ return self._instance
# lock to prevent resolving several times
async with self._asyncio_lock:
- if self._instance is not None:
+ if is_set(self._instance):
return self._instance
self._register_arguments()
self._instance = self._factory(
- *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
- **{ # type: ignore[arg-type]
- k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
- },
+ *await _resolve_arguments(self._args, self._args_are_providers),
+ **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers),
)
return self._instance
@override
def resolve_sync(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
self._register_arguments()
return typing.cast(T_co, self._override)
+ if is_set(self._instance):
+ return self._instance
# lock to prevent resolving several times
with self._threading_lock:
- if self._instance is not None:
+ if is_set(self._instance):
return self._instance
self._register_arguments()
self._instance = self._factory(
- *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
- **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type]
+ *_resolve_arguments_sync(self._args, self._args_are_providers),
+ **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers),
)
return self._instance
@@ -103,8 +131,8 @@ async def tear_down(self, propagate: bool = True) -> None:
After calling this method, the next resolve() call will recreate the instance.
"""
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
await self._tear_down_children()
@@ -115,8 +143,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) ->
After calling this method, the next resolve call will recreate the instance.
"""
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)
@@ -142,7 +170,17 @@ async def my_async_factory() -> float:
"""
- __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override"
+ __slots__ = (
+ "_args",
+ "_args_are_providers",
+ "_asyncio_lock",
+ "_factory",
+ "_instance",
+ "_kwargs",
+ "_kwargs_are_providers",
+ "_kwargs_items",
+ "_override",
+ )
def __init__(
self,
@@ -159,37 +197,45 @@ def __init__(
"""
super().__init__()
- self._factory: typing.Final[typing.Callable[P, typing.Awaitable[T_co]]] = factory
- self._instance: T_co | None = None
+ self._factory: typing.Final[typing.Callable[..., typing.Awaitable[T_co]]] = factory
+ self._instance: T_co | Unset = UNSET
self._asyncio_lock: typing.Final = asyncio.Lock()
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
+ self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args)
+ self._kwargs_items: typing.Final = tuple(kwargs.items())
+ self._kwargs_are_providers: typing.Final = tuple(
+ isinstance(value, AbstractProvider) for _, value in self._kwargs_items
+ )
def _register_arguments(self) -> None:
+ if not self._mark_arguments_registered():
+ return
self._register(self._args)
self._register(self._kwargs.values())
def _deregister_arguments(self) -> None:
self._deregister(self._args)
self._deregister(self._kwargs.values())
+ self._reset_arguments_registration()
@override
async def resolve(self) -> T_co:
- if self._override is not None:
+ if is_set(self._override):
return typing.cast(T_co, self._override)
+ if is_set(self._instance):
+ return self._instance
# lock to prevent resolving several times
async with self._asyncio_lock:
- if self._instance is not None:
+ if is_set(self._instance):
return self._instance
self._register_arguments()
self._instance = await self._factory(
- *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
- **{ # type: ignore[arg-type]
- k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
- },
+ *await _resolve_arguments(self._args, self._args_are_providers),
+ **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers),
)
return self._instance
@@ -204,8 +250,8 @@ async def tear_down(self, propagate: bool = True) -> None:
After calling this method, the next call to ``resolve()`` will recreate the instance.
"""
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
await self._tear_down_children()
@@ -216,8 +262,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) ->
After calling this method, the next call to ``resolve_sync()`` will recreate the instance.
"""
- if self._instance is not None:
- self._instance = None
+ if is_set(self._instance):
+ self._instance = UNSET
self._deregister_arguments()
if propagate:
self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)