From 845ad9e79949c550018ea46b487cece6eaf4d04c Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 15 Mar 2026 00:10:56 +0100 Subject: [PATCH 01/13] types: correct type for resource creator instead of typing.Any. --- that_depends/providers/base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 9c30bdf..f46901f 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -306,7 +306,9 @@ def __init__( """ super().__init__() - self._creator: typing.Any + self._creator: ( + typing.Callable[P, typing.ContextManager[T_co]] | typing.Callable[P, typing.AsyncContextManager[T_co]] + ) if inspect.isasyncgenfunction(creator): self._is_async = True @@ -352,8 +354,8 @@ async def resolve(self) -> T_co: self._register_arguments() cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], - **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, + *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type:ignore[arg-type] + **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type:ignore[arg-type] ) if isinstance(cm, typing.AsyncContextManager): @@ -388,11 +390,11 @@ def resolve_sync(self) -> T_co: self._register_arguments() cm = self._creator( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, + *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type:ignore[arg-type] + **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type:ignore[arg-type] ) context.context_stack = contextlib.ExitStack() - context.instance = context.context_stack.enter_context(cm) + context.instance = context.context_stack.enter_context(cm) # type:ignore[arg-type] return context.instance From 56fcf79d79364a2d380268ecfd2fdfa2dd111e12 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Apr 2026 15:34:32 +0200 Subject: [PATCH 02/13] feat: performance improvements. --- that_depends/injection.py | 353 ++++++++++++-------- that_depends/meta.py | 7 + that_depends/providers/base.py | 120 ++++++- that_depends/providers/context_resources.py | 159 +++++++-- that_depends/providers/factories.py | 66 ++-- that_depends/providers/local_singleton.py | 24 +- that_depends/providers/mixin.py | 16 + that_depends/providers/singleton.py | 65 +++- 8 files changed, 593 insertions(+), 217 deletions(-) diff --git a/that_depends/injection.py b/that_depends/injection.py index 3a48dba..1a74846 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -1,19 +1,18 @@ -import asyncio -import contextlib import functools import inspect import re -import threading import typing import warnings -from contextlib import AsyncExitStack, ExitStack +from contextlib import AsyncExitStack +from types import TracebackType + +from typing_extensions import Self from that_depends.container import BaseContainer from that_depends.exceptions import TypeNotBoundError from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, ContextResource from that_depends.providers.context_resources import ContextScope, ContextScopes, container_context -from that_depends.providers.mixin import ProviderWithArguments class ContextProviderError(Exception): @@ -27,6 +26,91 @@ class ContextProviderError(Exception): _PROVIDE_MESSAGE: typing.Final[str] = ( "Use @Container.inject or @inject(container=Container) if you wish to use Provide()" ) +_INJECT_DIRECT_PROVIDER: typing.Final = 1 +_INJECT_STRING_PROVIDER: typing.Final = 2 +_INJECT_TYPED_PROVIDER: typing.Final = 3 + + +class _InjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + kind: int + dependency: typing.Any + + +class _DirectInjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + provider: AbstractProvider[typing.Any] + scope_init_order: tuple[AbstractProvider[typing.Any], ...] + + +class _InjectionPlan(typing.NamedTuple): + direct_parameters: tuple[_DirectInjectionParameter, ...] + dynamic_parameters: tuple[_InjectionParameter, ...] + + +class _SyncInjectionStack: + __slots__ = ("_exit_states",) + + def __init__(self) -> None: + self._exit_states: list[_SupportsClose] = [] + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> typing.Literal[False]: + _ = exc_type, exc_value, traceback + while self._exit_states: + self._exit_states.pop().close() + return False + + def enter_context(self, context_manager: typing.ContextManager[T]) -> T: + value = context_manager.__enter__() + self._exit_states.append(_ContextManagerExitState(context_manager)) + return value + + def push_exit_state(self, exit_state: "_SupportsClose") -> None: + self._exit_states.append(exit_state) + + +class _SupportsClose(typing.Protocol): + def close(self) -> None: ... + + +class _ContextManagerExitState: + __slots__ = ("_context_manager",) + + def __init__(self, context_manager: typing.ContextManager[typing.Any]) -> None: + self._context_manager = context_manager + + def close(self) -> None: + self._context_manager.__exit__(None, None, None) + + +@functools.cache +def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionPlan: + direct_parameters: list[_DirectInjectionParameter] = [] + dynamic_parameters: list[_InjectionParameter] = [] + for index, (field_name, param) in enumerate(inspect.signature(func).parameters.items()): + default = param.default + if isinstance(default, StringProviderDefinition): + dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_STRING_PROVIDER, default)) + elif isinstance(default, AbstractProvider): + direct_parameters.append( + _DirectInjectionParameter(index, field_name, default, default._get_scope_init_order()) # noqa: SLF001 + ) + elif isinstance(default, _Provide): + dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_TYPED_PROVIDER, param.annotation)) + return _InjectionPlan( + direct_parameters=tuple(direct_parameters), + dynamic_parameters=tuple(dynamic_parameters), + ) @typing.overload @@ -88,10 +172,11 @@ def _inject( def _inject_to_sync_gen( gen: typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]], ) -> typing.Callable[P, typing.Generator[T, typing.Any, typing.Any]]: + plan = _build_injection_plan(gen) + @functools.wraps(gen) def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any, typing.Any]: - signature = inspect.signature(gen) - injected, kwargs = _resolve_arguments_sync(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=2) @@ -105,11 +190,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> typing.Generator[T, typing.Any, def _inject_to_async_gen( gen: typing.Callable[P, typing.AsyncGenerator[T, typing.Any]], ) -> typing.Callable[P, typing.AsyncGenerator[T, typing.Any]]: + plan = _build_injection_plan(gen) + @functools.wraps(gen) async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, typing.Any]: - signature = inspect.signature(gen) - - injected, kwargs = await _resolve_arguments_async(signature, scope, container, None, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = await _resolve_arguments_async(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1) @@ -122,26 +207,30 @@ async def inner(*args: P.args, **kwargs: P.kwargs) -> typing.AsyncGenerator[T, t def _inject_to_async( func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], ) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]: + plan = _build_injection_plan(func) + @functools.wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> T: if enter_scope: async with container_context(scope=scope): - return await _resolve_async(func, None, container, *args, **kwargs) + return await _resolve_async(func, plan, None, container, *args, **kwargs) else: - return await _resolve_async(func, scope, container, *args, **kwargs) + return await _resolve_async(func, plan, scope, container, *args, **kwargs) return inner def _inject_to_sync( func: typing.Callable[P, T], ) -> typing.Callable[P, T]: + plan = _build_injection_plan(func) + @functools.wraps(func) def inner(*args: P.args, **kwargs: P.kwargs) -> T: if enter_scope: with container_context(scope=scope): - return _resolve_sync(func, None, container, *args, **kwargs) + return _resolve_sync(func, plan, None, container, *args, **kwargs) else: - return _resolve_sync(func, scope, container, *args, **kwargs) + return _resolve_sync(func, plan, scope, container, *args, **kwargs) return inner @@ -150,131 +239,128 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T: return _inject(func) -_SYNC_SIGNATURE_CACHE: dict[typing.Callable[..., typing.Any], inspect.Signature] = {} -_THREADING_LOCK = threading.Lock() - - async def _resolve_arguments_async( - signature: inspect.Signature, + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None, stack: AsyncExitStack | None, *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - injected = False + if not plan.direct_parameters and not plan.dynamic_parameters: + return False, kwargs + context_providers: set[AbstractProvider[typing.Any]] = set() - params = list(signature.parameters.items()) + for direct_parameter in plan.direct_parameters: + if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: + continue - for i, (field_name, param) in enumerate(params): - default = param.default + await _setup_scope_contexts_async(direct_parameter.scope_init_order, scope, stack, context_providers) + kwargs[direct_parameter.field_name] = await direct_parameter.provider.resolve() - if i < len(args) or field_name in kwargs: - if isinstance(default, (AbstractProvider, StringProviderDefinition)): - injected = True + for dynamic_parameter in plan.dynamic_parameters: + if dynamic_parameter.argument_index < len(args) or dynamic_parameter.field_name in kwargs: continue - if isinstance(default, StringProviderDefinition): - injected = True - resolved_val = await _resolve_provider_with_scope_async(default.provider, scope, stack, context_providers) - kwargs[field_name] = resolved_val - elif isinstance(default, AbstractProvider): - injected = True - resolved_val = await _resolve_provider_with_scope_async(default, scope, stack, context_providers) - kwargs[field_name] = resolved_val - - elif isinstance(default, _Provide): - injected = True - if container is None: - raise RuntimeError(_PROVIDE_MESSAGE) - try: - provider = container.get_provider_for_type(signature.parameters[field_name].annotation) - except TypeNotBoundError as e: - msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider." - raise RuntimeError(msg) from e - kwargs[field_name] = await _resolve_provider_with_scope_async(provider, scope, stack, context_providers) - return injected, kwargs + provider = _resolve_injected_provider(dynamic_parameter, container) + kwargs[dynamic_parameter.field_name] = await _resolve_provider_with_scope_async( + provider, + scope, + stack, + context_providers, + ) + return True, kwargs def _resolve_arguments_sync( - signature: inspect.Signature, + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None, - stack: contextlib.ExitStack | None, + stack: _SyncInjectionStack | None, *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - injected = False + if not plan.direct_parameters and not plan.dynamic_parameters: + return False, kwargs + context_providers: set[AbstractProvider[typing.Any]] = set() - for i, (field_name, param) in enumerate(signature.parameters.items()): - default = param.default - if i < len(args) or field_name in kwargs: - if isinstance(default, (AbstractProvider, StringProviderDefinition, _Provide)): - injected = True + for direct_parameter in plan.direct_parameters: + if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: continue - if isinstance(default, StringProviderDefinition): - injected = True - kwargs[field_name] = _resolve_provider_with_scope_sync(default.provider, scope, stack, context_providers) - elif isinstance(default, AbstractProvider): - injected = True - kwargs[field_name] = _resolve_provider_with_scope_sync(default, scope, stack, context_providers) - elif isinstance(default, _Provide): - injected = True - if container is None: - raise RuntimeError(_PROVIDE_MESSAGE) - try: - provider = container.get_provider_for_type(signature.parameters[field_name].annotation) - except TypeNotBoundError as e: - msg = f"Type {signature.parameters[field_name].annotation} is not bound to a provider." - raise RuntimeError(msg) from e - kwargs[field_name] = _resolve_provider_with_scope_sync(provider, scope, stack, context_providers) + _setup_scope_contexts_sync(direct_parameter.scope_init_order, scope, stack, context_providers) + kwargs[direct_parameter.field_name] = direct_parameter.provider.resolve_sync() + + for dynamic_parameter in plan.dynamic_parameters: + if dynamic_parameter.argument_index < len(args) or dynamic_parameter.field_name in kwargs: + continue + + provider = _resolve_injected_provider(dynamic_parameter, container) + kwargs[dynamic_parameter.field_name] = _resolve_provider_with_scope_sync( + provider, + scope, + stack, + context_providers, + ) - return injected, kwargs + return True, kwargs + + +def _resolve_injected_provider( + parameter: _InjectionParameter, + container: BaseContainerMeta | None, +) -> AbstractProvider[typing.Any]: + if parameter.kind == _INJECT_DIRECT_PROVIDER: + return typing.cast(AbstractProvider[typing.Any], parameter.dependency) + if parameter.kind == _INJECT_STRING_PROVIDER: + string_definition = typing.cast(StringProviderDefinition, parameter.dependency) + return string_definition.provider + if container is None: + raise RuntimeError(_PROVIDE_MESSAGE) + annotation = parameter.dependency + try: + return container.get_provider_for_type(annotation) + except TypeNotBoundError as e: + msg = f"Type {annotation} is not bound to a provider." + raise RuntimeError(msg) from e def _resolve_sync( func: typing.Callable[P, T], + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None = None, *args: P.args, **kwargs: P.kwargs, ) -> T: - if func not in _SYNC_SIGNATURE_CACHE: - with _THREADING_LOCK: - _SYNC_SIGNATURE_CACHE[func] = inspect.signature(func) - signature = _SYNC_SIGNATURE_CACHE[func] - - with ExitStack() as stack: - injected, kwargs = _resolve_arguments_sync(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment] - + if scope is None: + injected, kwargs = _resolve_arguments_sync(plan, scope, container, None, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3) return func(*args, **kwargs) + with _SyncInjectionStack() as stack: + injected, kwargs = _resolve_arguments_sync(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment] -_SIGNATURE_CACHE: dict[ - typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, typing.Any]], inspect.Signature -] = {} + if not injected: + warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=3) -_ASYNCIO_LOCK = asyncio.Lock() + return func(*args, **kwargs) + + raise RuntimeError # pragma: no cover # to prevent mypy issue async def _resolve_async( func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], + plan: _InjectionPlan, scope: ContextScope | None, container: BaseContainerMeta | None = None, *args: P.args, **kwargs: P.kwargs, ) -> T: - if func not in _SIGNATURE_CACHE: - async with _ASYNCIO_LOCK: - _SIGNATURE_CACHE[func] = inspect.signature(func) - signature = _SIGNATURE_CACHE[func] - async with AsyncExitStack() as stack: - injected, kwargs = await _resolve_arguments_async(signature, scope, container, stack, *args, **kwargs) # type: ignore[assignment] + injected, kwargs = await _resolve_arguments_async(plan, scope, container, stack, *args, **kwargs) # type: ignore[assignment] if not injected: warnings.warn(_INJECTION_WARNING_MESSAGE, RuntimeWarning, stacklevel=1) @@ -306,79 +392,68 @@ async def _resolve_provider_with_scope_async( ContextProviderError: if the stack is None. """ - await _add_provider_to_stack_async(provider, stack, scope, providers) + await _setup_scope_contexts_async(provider._get_scope_init_order(), scope, stack, providers) # noqa: SLF001 return await provider.resolve() -async def _add_provider_to_stack_async( - provider: AbstractProvider[T], - stack: AsyncExitStack | None, +async def _setup_scope_contexts_async( + scope_init_order: tuple[AbstractProvider[typing.Any], ...], scope: ContextScope | None, + stack: AsyncExitStack | None, providers: set[AbstractProvider[typing.Any]], ) -> None: - if provider in providers: - return - providers.add(provider) - if not scope: return - if isinstance(provider, ProviderWithArguments): - provider._register_arguments() # noqa: SLF001 - - parents = provider._parents # noqa: SLF001 - for parent in parents: - await _add_provider_to_stack_async(parent, stack, scope, providers) - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() - if provider_scope in (ContextScopes.ANY, scope): - if stack is None: - msg = ( - f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" - f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." - ) - raise ContextProviderError(msg) - await stack.enter_async_context(provider.context_async(force=True)) + for provider in scope_init_order: + if provider in providers: + continue + providers.add(provider) + if isinstance(provider, ContextResource): + provider_scope = provider.get_scope() + if provider_scope is ContextScopes.ANY or provider_scope is scope: + if stack is None: + msg = ( + f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" + f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." + ) + raise ContextProviderError(msg) + await stack.enter_async_context(provider.context_async(force=True)) def _resolve_provider_with_scope_sync( provider: AbstractProvider[T], scope: ContextScope | None, - stack: ExitStack | None, + stack: _SyncInjectionStack | None, providers: set[AbstractProvider[typing.Any]], ) -> T: - _add_provider_to_stack_sync(provider, stack, scope, providers) + _setup_scope_contexts_sync(provider._get_scope_init_order(), scope, stack, providers) # noqa: SLF001 return provider.resolve_sync() -def _add_provider_to_stack_sync( - provider: AbstractProvider[T], - stack: ExitStack | None, +def _setup_scope_contexts_sync( + scope_init_order: tuple[AbstractProvider[typing.Any], ...], scope: ContextScope | None, + stack: _SyncInjectionStack | None, providers: set[AbstractProvider[typing.Any]], ) -> None: - if provider in providers: - return - providers.add(provider) - if not scope: return - if isinstance(provider, ProviderWithArguments): - provider._register_arguments() # noqa: SLF001 - - parents = provider._parents # noqa: SLF001 - for parent in parents: - _add_provider_to_stack_sync(parent, stack, scope, providers) - - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() - if provider_scope in (ContextScopes.ANY, scope): - if stack is None: - msg = ( - f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" - f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." - ) - raise ContextProviderError(msg) - stack.enter_context(provider.context_sync(force=True)) + for provider in scope_init_order: + if provider in providers: + continue + providers.add(provider) + + if isinstance(provider, ContextResource): + provider_scope = provider.get_scope() + if provider_scope is ContextScopes.ANY or provider_scope is scope: + if stack is None: + msg = ( + f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" + f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." + ) + raise ContextProviderError(msg) + _, exit_state = provider._enter_injection_context_sync(force=True) # noqa: SLF001 + stack.push_exit_state(exit_state) class StringProviderDefinition: diff --git a/that_depends/meta.py b/that_depends/meta.py index 21507a4..a2bb21a 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -88,6 +88,7 @@ def supports_context_sync(cls) -> bool: "containers", "alias", "default_scope", + "type_provider_cache", ) _lock: Lock = Lock() @@ -163,12 +164,18 @@ def get_provider_for_type(cls, t: type[T]) -> AbstractProvider[T]: Provider for the given type. """ + if not hasattr(cls, "type_provider_cache"): + cls.type_provider_cache: dict[type[typing.Any], AbstractProvider[typing.Any]] = {} + if provider := cls.type_provider_cache.get(t): + return typing.cast(AbstractProvider[T], provider) for provider in cls.get_providers().values(): if provider._has_contravariant_bindings: # noqa: SLF001 for bind in provider._bindings: # noqa: SLF001 if issubclass(bind, t): + cls.type_provider_cache[t] = provider return provider elif t in provider._bindings: # noqa: SLF001 + cls.type_provider_cache[t] = provider return provider msg = f"Type {t} is not bound to any provider in container {cls.name()}" raise TypeNotBoundError(msg) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index f46901f..3b25f62 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -20,6 +20,66 @@ P, typing.Iterator[T_co] | typing.AsyncIterator[T_co] | typing.ContextManager[T_co] | typing.AsyncContextManager[T_co], ] +_EMPTY_ARGS: typing.Final[tuple[()]] = () +_EMPTY_KWARGS: typing.Final[dict[str, typing.Any]] = {} + + +async def _resolve_arguments( + args: tuple[typing.Any, ...], + args_are_providers: tuple[bool, ...], +) -> tuple[typing.Any, ...] | list[typing.Any]: + if not args: + return _EMPTY_ARGS + if len(args) == 1: + arg = args[0] + return (await arg.resolve() if args_are_providers[0] else arg,) + return [ + await arg.resolve() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False) + ] + + +def _resolve_arguments_sync( + args: tuple[typing.Any, ...], + args_are_providers: tuple[bool, ...], +) -> tuple[typing.Any, ...] | list[typing.Any]: + if not args: + return _EMPTY_ARGS + if len(args) == 1: + arg = args[0] + return (arg.resolve_sync() if args_are_providers[0] else arg,) + return [ + arg.resolve_sync() if is_provider else arg for arg, is_provider in zip(args, args_are_providers, strict=False) + ] + + +async def _resolve_keyword_arguments( + kwargs_items: tuple[tuple[str, typing.Any], ...], + kwargs_are_providers: tuple[bool, ...], +) -> dict[str, typing.Any]: + if not kwargs_items: + return _EMPTY_KWARGS + if len(kwargs_items) == 1: + key, value = kwargs_items[0] + return {key: await value.resolve() if kwargs_are_providers[0] else value} + return { + key: await value.resolve() if is_provider else value + for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False) + } + + +def _resolve_keyword_arguments_sync( + kwargs_items: tuple[tuple[str, typing.Any], ...], + kwargs_are_providers: tuple[bool, ...], +) -> dict[str, typing.Any]: + if not kwargs_items: + return _EMPTY_KWARGS + if len(kwargs_items) == 1: + key, value = kwargs_items[0] + return {key: value.resolve_sync() if kwargs_are_providers[0] else value} + return { + key: value.resolve_sync() if is_provider else value + for (key, value), is_provider in zip(kwargs_items, kwargs_are_providers, strict=False) + } class AbstractProvider(abc.ABC, typing.Generic[T_co]): @@ -30,6 +90,7 @@ def __init__(self) -> None: super().__init__() self._children: set[AbstractProvider[typing.Any]] = set() self._parents: set[AbstractProvider[typing.Any]] = set() + self._scope_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None self._override: typing.Any = None self._bindings: set[type] = set() self._has_contravariant_bindings: bool = False @@ -62,10 +123,14 @@ def _register(self, candidates: typing.Iterable[typing.Any]) -> None: None """ + changed = False for candidate in candidates: if isinstance(candidate, AbstractProvider): candidate.add_child_provider(self) self._parents.add(candidate) + changed = True + if changed: + self._invalidate_scope_init_order() def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None: """Deregister current provider as child. @@ -77,10 +142,48 @@ def _deregister(self, candidates: typing.Iterable[typing.Any]) -> None: None """ + changed = False for candidate in candidates: if isinstance(candidate, AbstractProvider) and self in candidate._children: # noqa: SLF001 candidate.remove_child_provider(self) self._parents.discard(candidate) + changed = True + if changed: + self._invalidate_scope_init_order() + + def _invalidate_scope_init_order(self) -> None: + stack = [self] + visited: set[AbstractProvider[typing.Any]] = set() + + while stack: + provider = stack.pop() + if provider in visited: + continue + visited.add(provider) + provider._scope_init_order = None # noqa: SLF001 + stack.extend(provider._children) # noqa: SLF001 + + def _get_scope_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]: + if self._scope_init_order is not None: + return self._scope_init_order + + if isinstance(self, ProviderWithArguments): + self._register_arguments() + + ordered: list[AbstractProvider[typing.Any]] = [] + seen: set[AbstractProvider[typing.Any]] = set() + + for parent in self._parents: + for ancestor in parent._get_scope_init_order(): # noqa: SLF001 + if ancestor not in seen: + seen.add(ancestor) + ordered.append(ancestor) + + if self not in seen: + ordered.append(self) + + self._scope_init_order = tuple(ordered) + return self._scope_init_order def add_child_provider(self, provider: "AbstractProvider[typing.Any]") -> None: """Add a child provider to the current provider. @@ -327,14 +430,20 @@ def __init__( raise TypeError(msg) self._args = args self._kwargs = kwargs + self._args_are_providers = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items = tuple(kwargs.items()) + self._kwargs_are_providers = tuple(isinstance(value, AbstractProvider) for _, value in self._kwargs_items) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @abc.abstractmethod def _fetch_context(self) -> ResourceContext[T_co]: ... @@ -354,8 +463,8 @@ async def resolve(self) -> T_co: self._register_arguments() cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type:ignore[arg-type] - **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type:ignore[arg-type] + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) if isinstance(cm, typing.AsyncContextManager): @@ -390,8 +499,8 @@ def resolve_sync(self) -> T_co: self._register_arguments() cm = self._creator( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type:ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type:ignore[arg-type] + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) context.context_stack = contextlib.ExitStack() context.instance = context.context_stack.enter_context(cm) # type:ignore[arg-type] @@ -411,6 +520,8 @@ class AttrGetter( """Provides an attribute after resolving the wrapped provider.""" def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return if isinstance(self._provider, ProviderWithArguments): self._provider._register_arguments() # noqa: SLF001 self._parents = self._provider._parents # noqa: SLF001 @@ -419,6 +530,7 @@ def _deregister_arguments(self) -> None: if isinstance(self._provider, ProviderWithArguments): self._provider._deregister_arguments() # noqa: SLF001 self._parents = self._provider._parents # noqa: SLF001 + self._reset_arguments_registration() __slots__ = "_attrs", "_provider" diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 4dab61a..e815333 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -41,6 +41,53 @@ class InvalidContextError(RuntimeError): """Raised when an invalid context is being used.""" +class _SyncInjectionExitState(typing.Generic[T_co]): + __slots__ = ("_provider", "_temp_token", "_token") + + def __init__( + self, + provider: "ContextResource[T_co]", + token: Token[ResourceContext[T_co]] | None, + temp_token: Token[ResourceContext[T_co]] | None, + ) -> None: + self._provider = provider + self._token = token + self._temp_token = temp_token + + def close(self) -> None: + with self._provider._lock: # noqa: SLF001 + self._provider._token = self._temp_token # noqa: SLF001 + self._provider._exit_context_sync() # noqa: SLF001 + self._provider._token = self._token # noqa: SLF001 + + +class _SyncContextResourceContext(contextlib.ContextDecorator, AbstractContextManager[ResourceContext[T_co]]): + __slots__ = ("_exit_state", "_force", "_provider") + + def __init__(self, provider: "ContextResource[T_co]", force: bool) -> None: + self._provider = provider + self._force = force + self._exit_state: _SyncInjectionExitState[T_co] | None = None + + @override + def __enter__(self) -> ResourceContext[T_co]: + value, self._exit_state = self._provider._enter_injection_context_sync(force=self._force) # noqa: SLF001 + return value + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if self._exit_state is None: + msg = "Context is not set, call ``__enter__`` first" + raise RuntimeError(msg) + _ = exc_type, exc_value, traceback + self._exit_state.close() + + class ContextScope: """A named context scope.""" @@ -229,7 +276,7 @@ class ContextResource( @override async def resolve(self) -> T_co: current_scope = get_current_scope() - if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + if not self._strict_scope or self._scope is ContextScopes.ANY or self._scope is current_scope: return await super().resolve() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @@ -237,7 +284,7 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: current_scope = get_current_scope() - if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + if not self._strict_scope or self._scope is ContextScopes.ANY or self._scope is current_scope: return super().resolve_sync() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @@ -350,18 +397,34 @@ def _enter_context_sync(self, force: bool = False) -> ResourceContext[T_co]: raise RuntimeError(msg) return self._enter(force) + def _enter_injection_context_sync( + self, + force: bool = False, + ) -> tuple[ResourceContext[T_co], _SyncInjectionExitState[T_co]]: + if self._is_async: + msg = "Please use async context instead." + raise RuntimeError(msg) + + token = self._token + with self._lock: + value = self._enter_context_sync(force=force) + temp_token = self._token + + return value, _SyncInjectionExitState(self, token, temp_token) + async def _enter_context_async(self, force: bool = False) -> ResourceContext[T_co]: return self._enter(force) def _enter(self, force: bool = False) -> ResourceContext[T_co]: - if not force and self._scope not in (ContextScopes.ANY, get_current_scope()): - msg = f"Cannot enter context for resource with scope {self._scope} in scope {get_current_scope()!r}" + current_scope = get_current_scope() + if not force and self._scope is not ContextScopes.ANY and self._scope is not current_scope: + msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" raise InvalidContextError(msg) self._token = self._context.set(ResourceContext(is_async=self._is_async)) return self._context.get() def _exit_context_sync(self) -> None: - if not self._token: + if self._token is None: msg = "Context is not set, call ``_enter_sync_context`` first" raise RuntimeError(msg) @@ -385,21 +448,9 @@ async def _exit_context_async(self) -> None: finally: self._context.reset(self._token) - @contextlib.contextmanager @override - def context_sync(self, force: bool = False) -> typing.Iterator[ResourceContext[T_co]]: - if self._is_async: - msg = "Please use async context instead." - raise RuntimeError(msg) - token = self._token - with self._lock: - val = self._enter_context_sync(force=force) - temp_token = self._token - yield val - with self._lock: - self._token = temp_token - self._exit_context_sync() - self._token = token + def context_sync(self, force: bool = False) -> _SyncContextResourceContext[T_co]: + return _SyncContextResourceContext(self, force) @contextlib.asynccontextmanager @override @@ -434,11 +485,14 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex """ __slots__ = ( + "_container_items", + "_container_providers_by_scope", "_containers", "_context_items", "_context_providers", "_context_stack", "_context_token", + "_direct_context_providers", "_global_context", "_initial_context", "_preserve_global_context", @@ -481,13 +535,44 @@ def __init__( self._global_context = global_context self._context_token: Token[ContextType] | None = None self._context_items: typing.Final[set[SupportsContext[typing.Any]]] = set(context_items) + self._container_items: tuple[type[BaseContainer], ...] + self._direct_context_providers: tuple[ContextResource[typing.Any], ...] + self._container_providers_by_scope: dict[ContextScope | None, tuple[ContextResource[typing.Any], ...]] = {} self._context_providers: set[ContextResource[typing.Any]] = set() self._reset_resource_context: typing.Final[bool] = bool(scope) self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None self._scope_token: Token[ContextScope | None] | None = None + self._container_items, self._direct_context_providers = self._parse_context_items(self._context_items) + + def _parse_context_items( + self, + context_items: set[SupportsContext[typing.Any]], + ) -> tuple[tuple[type["BaseContainer"], ...], tuple[ContextResource[typing.Any], ...]]: + from that_depends.container import BaseContainer # noqa: PLC0415 + + containers: list[type[BaseContainer]] = [] + providers: list[ContextResource[typing.Any]] = [] + for item in context_items: + if isinstance(item, type) and issubclass(item, BaseContainer): + containers.append(item) + elif isinstance(item, ContextResource): + providers.append(item) + return tuple(containers), tuple(providers) - def _resolve_initial_conditions(self) -> None: - self._scope = self._scope or get_current_scope() + def _get_context_providers_for_scope( + self, + scope: ContextScope | None, + ) -> set[ContextResource[typing.Any]]: + cached = self._container_providers_by_scope.get(scope) + if cached is None: + providers = set(self._direct_context_providers) + self._add_providers_from_containers(self._container_items, providers, scope) + cached = tuple(providers) + self._container_providers_by_scope[scope] = cached + return set(cached) + + def _resolve_initial_conditions(self) -> ContextScope | None: + scope = self._scope or get_current_scope() if self._preserve_global_context and self._global_context: if context := _get_container_context(): self._initial_context = {**context, **self._global_context} @@ -499,33 +584,35 @@ def _resolve_initial_conditions(self) -> None: ) else: self._initial_context = self._global_context or {} + self._context_providers = self._get_context_providers_for_scope(scope) if self._reset_resource_context: from that_depends.meta import BaseContainerMeta # noqa: PLC0415 - self._add_providers_from_containers(BaseContainerMeta.get_instances().values(), self._scope) - for item in self._context_items: - from that_depends.container import BaseContainer # noqa: PLC0415 - - if isinstance(item, type) and issubclass(item, BaseContainer): - self._add_providers_from_containers([item], self._scope) - elif isinstance(item, ContextResource): - self._context_providers.add(item) + self._add_providers_from_containers( + BaseContainerMeta.get_instances().values(), + self._context_providers, + scope, + ) + return scope def _add_providers_from_containers( - self, containers: Iterable[ContainerType], scope: ContextScope | None = ContextScopes.ANY + self, + containers: Iterable[ContainerType], + target: set[ContextResource[typing.Any]], + scope: ContextScope | None = ContextScopes.ANY, ) -> None: for container in containers: for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): provider_scope = container_provider.get_scope() - if provider_scope in (scope, ContextScopes.ANY): - self._context_providers.add(container_provider) + if provider_scope is scope or provider_scope is ContextScopes.ANY: + target.add(container_provider) @override def __enter__(self) -> ContextType: - self._resolve_initial_conditions() + scope = self._resolve_initial_conditions() self._context_stack = contextlib.ExitStack() - self._scope_token = _set_current_scope(self._scope) + self._scope_token = _set_current_scope(scope) for item in self._context_providers: if item.supports_context_sync(): self._context_stack.enter_context(item.context_sync()) @@ -533,9 +620,9 @@ def __enter__(self) -> ContextType: @override async def __aenter__(self) -> ContextType: - self._resolve_initial_conditions() + scope = self._resolve_initial_conditions() self._context_stack = contextlib.AsyncExitStack() - self._scope_token = _set_current_scope(self._scope) + self._scope_token = _set_current_scope(scope) for item in self._context_providers: await self._context_stack.enter_async_context(item.context_async()) return self._enter_globals() diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index 6242ea5..8a848a0 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -5,7 +5,13 @@ from typing_extensions import override -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import ( + AbstractProvider, + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments @@ -76,13 +82,23 @@ def build_resource(text: str, number: int): """ def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: raise NotImplementedError - __slots__ = "_args", "_factory", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_factory", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: """Initialize a Factory instance. @@ -97,6 +113,11 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) self._register_arguments() @override @@ -105,12 +126,8 @@ async def resolve(self) -> T_co: return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) @override @@ -119,12 +136,8 @@ def resolve_sync(self) -> T_co: return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) @@ -147,13 +160,23 @@ async def async_build_resource(text: str): """ def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: raise NotImplementedError - __slots__ = "_args", "_factory", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_factory", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) @overload def __init__( @@ -179,6 +202,11 @@ def __init__( self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) self._register_arguments() @override @@ -186,12 +214,12 @@ async def resolve(self) -> T_co: if self._override: return typing.cast(T_co, self._override) - args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args] - kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()} + args = await _resolve_arguments(self._args, self._args_are_providers) + kwargs = await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers) result = self._factory( - *args, # type:ignore[arg-type] - **kwargs, # type:ignore[arg-type] + *args, + **kwargs, ) if inspect.isawaitable(result): diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index dd8f1ef..709925c 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -5,6 +5,12 @@ from typing_extensions import override from that_depends.providers import AbstractProvider +from that_depends.providers.base import ( + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown @@ -55,14 +61,22 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P self._asyncio_lock = asyncio.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @property def _instance(self) -> T_co | None: @@ -84,10 +98,8 @@ async def resolve(self) -> T_co: self._register_arguments() self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @@ -102,8 +114,8 @@ def resolve_sync(self) -> T_co: self._register_arguments() self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) return self._instance diff --git a/that_depends/providers/mixin.py b/that_depends/providers/mixin.py index 5462230..99597af 100644 --- a/that_depends/providers/mixin.py +++ b/that_depends/providers/mixin.py @@ -26,6 +26,22 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> class ProviderWithArguments(abc.ABC): """Interface for providers that require arguments.""" + __slots__ = ("_arguments_registered",) + + def __init__(self) -> None: + """Initialize provider argument registration state.""" + super().__init__() + self._arguments_registered = False + + def _mark_arguments_registered(self) -> bool: + if self._arguments_registered: + return False + self._arguments_registered = True + return True + + def _reset_arguments_registration(self) -> None: + self._arguments_registered = False + @abc.abstractmethod def _register_arguments(self) -> None: """Register arguments for the provider.""" diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index daa470d..7ee0bca 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -6,7 +6,13 @@ from typing_extensions import override -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import ( + AbstractProvider, + _resolve_arguments, + _resolve_arguments_sync, + _resolve_keyword_arguments, + _resolve_keyword_arguments_sync, +) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown @@ -34,7 +40,18 @@ def my_factory() -> float: """ - __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override", "_threading_lock" + __slots__ = ( + "_args", + "_args_are_providers", + "_asyncio_lock", + "_factory", + "_instance", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + "_threading_lock", + ) def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: """Initialize the Singleton provider. @@ -52,14 +69,22 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P self._threading_lock: typing.Final = threading.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @override async def resolve(self) -> T_co: @@ -73,10 +98,8 @@ async def resolve(self) -> T_co: return self._instance self._register_arguments() self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @@ -92,8 +115,8 @@ def resolve_sync(self) -> T_co: return self._instance self._register_arguments() self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + *_resolve_arguments_sync(self._args, self._args_are_providers), + **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) return self._instance @@ -142,7 +165,17 @@ async def my_async_factory() -> float: """ - __slots__ = "_args", "_asyncio_lock", "_factory", "_instance", "_kwargs", "_override" + __slots__ = ( + "_args", + "_args_are_providers", + "_asyncio_lock", + "_factory", + "_instance", + "_kwargs", + "_kwargs_are_providers", + "_kwargs_items", + "_override", + ) def __init__( self, @@ -164,14 +197,22 @@ def __init__( self._asyncio_lock: typing.Final = asyncio.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs + self._args_are_providers: typing.Final = tuple(isinstance(arg, AbstractProvider) for arg in args) + self._kwargs_items: typing.Final = tuple(kwargs.items()) + self._kwargs_are_providers: typing.Final = tuple( + isinstance(value, AbstractProvider) for _, value in self._kwargs_items + ) def _register_arguments(self) -> None: + if not self._mark_arguments_registered(): + return self._register(self._args) self._register(self._kwargs.values()) def _deregister_arguments(self) -> None: self._deregister(self._args) self._deregister(self._kwargs.values()) + self._reset_arguments_registration() @override async def resolve(self) -> T_co: @@ -186,10 +227,8 @@ async def resolve(self) -> T_co: self._register_arguments() self._instance = await self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *await _resolve_arguments(self._args, self._args_are_providers), + **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), ) return self._instance From e94524d0f94f1c4bb60f201b8c164a7232ab0681 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Apr 2026 16:05:46 +0200 Subject: [PATCH 03/13] feat: further performance optimizations. --- that_depends/injection.py | 76 +++++++++++++-------- that_depends/providers/base.py | 29 +++++++- that_depends/providers/context_resources.py | 65 +++++++++++------- that_depends/providers/local_singleton.py | 2 + that_depends/providers/singleton.py | 6 ++ 5 files changed, 121 insertions(+), 57 deletions(-) diff --git a/that_depends/injection.py b/that_depends/injection.py index 1a74846..abd4edc 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -11,7 +11,7 @@ from that_depends.container import BaseContainer from that_depends.exceptions import TypeNotBoundError from that_depends.meta import BaseContainerMeta -from that_depends.providers import AbstractProvider, ContextResource +from that_depends.providers import AbstractProvider from that_depends.providers.context_resources import ContextScope, ContextScopes, container_context @@ -42,7 +42,7 @@ class _DirectInjectionParameter(typing.NamedTuple): argument_index: int field_name: str provider: AbstractProvider[typing.Any] - scope_init_order: tuple[AbstractProvider[typing.Any], ...] + scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] class _InjectionPlan(typing.NamedTuple): @@ -103,7 +103,12 @@ def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionP dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_STRING_PROVIDER, default)) elif isinstance(default, AbstractProvider): direct_parameters.append( - _DirectInjectionParameter(index, field_name, default, default._get_scope_init_order()) # noqa: SLF001 + _DirectInjectionParameter( + index, + field_name, + default, + default._get_scope_context_init_order(), # noqa: SLF001 + ) ) elif isinstance(default, _Provide): dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_TYPED_PROVIDER, param.annotation)) @@ -255,7 +260,13 @@ async def _resolve_arguments_async( if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: continue - await _setup_scope_contexts_async(direct_parameter.scope_init_order, scope, stack, context_providers) + if direct_parameter.scope_context_init_order: + await _setup_scope_contexts_async( + direct_parameter.scope_context_init_order, + scope, + stack, + context_providers, + ) kwargs[direct_parameter.field_name] = await direct_parameter.provider.resolve() for dynamic_parameter in plan.dynamic_parameters: @@ -288,7 +299,13 @@ def _resolve_arguments_sync( if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: continue - _setup_scope_contexts_sync(direct_parameter.scope_init_order, scope, stack, context_providers) + if direct_parameter.scope_context_init_order: + _setup_scope_contexts_sync( + direct_parameter.scope_context_init_order, + scope, + stack, + context_providers, + ) kwargs[direct_parameter.field_name] = direct_parameter.provider.resolve_sync() for dynamic_parameter in plan.dynamic_parameters: @@ -392,7 +409,9 @@ async def _resolve_provider_with_scope_async( ContextProviderError: if the stack is None. """ - await _setup_scope_contexts_async(provider._get_scope_init_order(), scope, stack, providers) # noqa: SLF001 + scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001 + if scope_context_init_order: + await _setup_scope_contexts_async(scope_context_init_order, scope, stack, providers) return await provider.resolve() @@ -408,16 +427,15 @@ async def _setup_scope_contexts_async( if provider in providers: continue providers.add(provider) - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() - if provider_scope is ContextScopes.ANY or provider_scope is scope: - if stack is None: - msg = ( - f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" - f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." - ) - raise ContextProviderError(msg) - await stack.enter_async_context(provider.context_async(force=True)) + provider_scope = provider._scope # noqa: SLF001 + if provider_scope is ContextScopes.ANY or provider_scope is scope: + if stack is None: + msg = ( + f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" + f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." + ) + raise ContextProviderError(msg) + await stack.enter_async_context(provider.context_async(force=True)) def _resolve_provider_with_scope_sync( @@ -426,7 +444,9 @@ def _resolve_provider_with_scope_sync( stack: _SyncInjectionStack | None, providers: set[AbstractProvider[typing.Any]], ) -> T: - _setup_scope_contexts_sync(provider._get_scope_init_order(), scope, stack, providers) # noqa: SLF001 + scope_context_init_order = provider._get_scope_context_init_order() # noqa: SLF001 + if scope_context_init_order: + _setup_scope_contexts_sync(scope_context_init_order, scope, stack, providers) return provider.resolve_sync() @@ -442,18 +462,16 @@ def _setup_scope_contexts_sync( if provider in providers: continue providers.add(provider) - - if isinstance(provider, ContextResource): - provider_scope = provider.get_scope() - if provider_scope is ContextScopes.ANY or provider_scope is scope: - if stack is None: - msg = ( - f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" - f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." - ) - raise ContextProviderError(msg) - _, exit_state = provider._enter_injection_context_sync(force=True) # noqa: SLF001 - stack.push_exit_state(exit_state) + provider_scope = provider._scope # noqa: SLF001 + if provider_scope is ContextScopes.ANY or provider_scope is scope: + if stack is None: + msg = ( + f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" + f"Note: @inject cannot initialize context for ContextResources when wrapping a generator." + ) + raise ContextProviderError(msg) + _, exit_state = provider._enter_injection_context_sync(force=True) # noqa: SLF001 + stack.push_exit_state(exit_state) class StringProviderDefinition: diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 3b25f62..4c3e73a 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -90,6 +90,8 @@ def __init__(self) -> None: super().__init__() self._children: set[AbstractProvider[typing.Any]] = set() self._parents: set[AbstractProvider[typing.Any]] = set() + self._is_context_resource = False + self._scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None self._scope_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None self._override: typing.Any = None self._bindings: set[type] = set() @@ -160,6 +162,7 @@ def _invalidate_scope_init_order(self) -> None: if provider in visited: continue visited.add(provider) + provider._scope_context_init_order = None # noqa: SLF001 provider._scope_init_order = None # noqa: SLF001 stack.extend(provider._children) # noqa: SLF001 @@ -185,6 +188,28 @@ def _get_scope_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]: self._scope_init_order = tuple(ordered) return self._scope_init_order + def _get_scope_context_init_order(self) -> tuple["AbstractProvider[typing.Any]", ...]: + if self._scope_context_init_order is not None: + return self._scope_context_init_order + + if isinstance(self, ProviderWithArguments): + self._register_arguments() + + ordered: list[AbstractProvider[typing.Any]] = [] + seen: set[AbstractProvider[typing.Any]] = set() + + for parent in self._parents: + for ancestor in parent._get_scope_context_init_order(): # noqa: SLF001 + if ancestor not in seen: + seen.add(ancestor) + ordered.append(ancestor) + + if self._is_context_resource and self not in seen: + ordered.append(self) + + self._scope_context_init_order = tuple(ordered) + return self._scope_context_init_order + def add_child_provider(self, provider: "AbstractProvider[typing.Any]") -> None: """Add a child provider to the current provider. @@ -461,7 +486,6 @@ async def resolve(self) -> T_co: return context.instance self._register_arguments() - cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator( *await _resolve_arguments(self._args, self._args_are_providers), **await _resolve_keyword_arguments(self._kwargs_items, self._kwargs_are_providers), @@ -497,13 +521,12 @@ def resolve_sync(self) -> T_co: raise RuntimeError(msg) self._register_arguments() - cm = self._creator( *_resolve_arguments_sync(self._args, self._args_are_providers), **_resolve_keyword_arguments_sync(self._kwargs_items, self._kwargs_are_providers), ) context.context_stack = contextlib.ExitStack() - context.instance = context.context_stack.enter_context(cm) # type:ignore[arg-type] + context.instance = context.context_stack.enter_context(cm) # type: ignore[arg-type] return context.instance diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index e815333..fd0c8d9 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -42,23 +42,25 @@ class InvalidContextError(RuntimeError): class _SyncInjectionExitState(typing.Generic[T_co]): - __slots__ = ("_provider", "_temp_token", "_token") + __slots__ = ("_context", "_context_item", "_token") def __init__( self, - provider: "ContextResource[T_co]", - token: Token[ResourceContext[T_co]] | None, - temp_token: Token[ResourceContext[T_co]] | None, + context: ContextVar[ResourceContext[T_co]], + context_item: ResourceContext[T_co], + token: Token[ResourceContext[T_co]], ) -> None: - self._provider = provider + self._context = context + self._context_item = context_item self._token = token - self._temp_token = temp_token def close(self) -> None: - with self._provider._lock: # noqa: SLF001 - self._provider._token = self._temp_token # noqa: SLF001 - self._provider._exit_context_sync() # noqa: SLF001 - self._provider._token = self._token # noqa: SLF001 + context_stack = self._context_item.context_stack + if context_stack is not None: + context_stack.close() # type: ignore[union-attr] + self._context_item.context_stack = None + self._context_item.instance = None + self._context.reset(self._token) class _SyncContextResourceContext(contextlib.ContextDecorator, AbstractContextManager[ResourceContext[T_co]]): @@ -275,16 +277,20 @@ class ContextResource( @override async def resolve(self) -> T_co: + if not self._strict_scope or self._scope is ContextScopes.ANY: + return await super().resolve() current_scope = get_current_scope() - if not self._strict_scope or self._scope is ContextScopes.ANY or self._scope is current_scope: + if self._scope is current_scope: return await super().resolve() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @override def resolve_sync(self) -> T_co: + if not self._strict_scope or self._scope is ContextScopes.ANY: + return super().resolve_sync() current_scope = get_current_scope() - if not self._strict_scope or self._scope is ContextScopes.ANY or self._scope is current_scope: + if self._scope is current_scope: return super().resolve_sync() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @@ -321,6 +327,7 @@ def __init__( """ super().__init__(creator, *args, **kwargs) + self._is_context_resource = True self._from_creator = creator self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") self._token: Token[ResourceContext[T_co]] | None = None @@ -404,24 +411,28 @@ def _enter_injection_context_sync( if self._is_async: msg = "Please use async context instead." raise RuntimeError(msg) + if not force and self._scope is not ContextScopes.ANY: + current_scope = get_current_scope() + if self._scope is not current_scope: + msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" + raise InvalidContextError(msg) - token = self._token - with self._lock: - value = self._enter_context_sync(force=force) - temp_token = self._token - - return value, _SyncInjectionExitState(self, token, temp_token) + context_item: ResourceContext[T_co] = ResourceContext(is_async=False) + token = self._context.set(context_item) + return context_item, _SyncInjectionExitState(self._context, context_item, token) async def _enter_context_async(self, force: bool = False) -> ResourceContext[T_co]: return self._enter(force) def _enter(self, force: bool = False) -> ResourceContext[T_co]: - current_scope = get_current_scope() - if not force and self._scope is not ContextScopes.ANY and self._scope is not current_scope: - msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" - raise InvalidContextError(msg) - self._token = self._context.set(ResourceContext(is_async=self._is_async)) - return self._context.get() + if not force and self._scope is not ContextScopes.ANY: + current_scope = get_current_scope() + if self._scope is not current_scope: + msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" + raise InvalidContextError(msg) + context_item: ResourceContext[T_co] = ResourceContext(is_async=self._is_async) + self._token = self._context.set(context_item) + return context_item def _exit_context_sync(self) -> None: if self._token is None: @@ -430,7 +441,11 @@ def _exit_context_sync(self) -> None: try: context_item = self._context.get() - context_item.tear_down_sync() + context_stack = context_item.context_stack + if context_stack is not None and ResourceContext.is_context_stack_sync(context_stack): + context_stack.close() + context_item.context_stack = None + context_item.instance = None finally: self._context.reset(self._token) diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index 709925c..5146f52 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -90,6 +90,8 @@ def _instance(self, value: T_co | None) -> None: async def resolve(self) -> T_co: if self._override is not None: return typing.cast(T_co, self._override) + if self._instance is not None: + return self._instance async with self._asyncio_lock: if self._instance is not None: diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 7ee0bca..6e3cd56 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -91,6 +91,8 @@ async def resolve(self) -> T_co: if self._override is not None: self._register_arguments() return typing.cast(T_co, self._override) + if self._instance is not None: + return self._instance # lock to prevent resolving several times async with self._asyncio_lock: @@ -108,6 +110,8 @@ def resolve_sync(self) -> T_co: if self._override is not None: self._register_arguments() return typing.cast(T_co, self._override) + if self._instance is not None: + return self._instance # lock to prevent resolving several times with self._threading_lock: @@ -218,6 +222,8 @@ def _deregister_arguments(self) -> None: async def resolve(self) -> T_co: if self._override is not None: return typing.cast(T_co, self._override) + if self._instance is not None: + return self._instance # lock to prevent resolving several times async with self._asyncio_lock: From bfc13a5235e819e5496648ff98c6382e012d91f2 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Apr 2026 16:09:31 +0200 Subject: [PATCH 04/13] benchmark: Updated benchmark results. --- examples/benchmark/RESULTS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/benchmark/RESULTS.md b/examples/benchmark/RESULTS.md index 90ff8dd..ef211fb 100644 --- a/examples/benchmark/RESULTS.md +++ b/examples/benchmark/RESULTS.md @@ -3,6 +3,8 @@ Based on this [benchmark](injection.py): | version / iterations | 10^4 | 10^5 | 10^6 | |----------------------|--------|--------|---------| +| 4.0 | 0.0399 | 0.3950 | 3.8107 | +| 3.9.2 | 0.0920 | 0.9804 | 9.1009 | | 3.2.0 | 0.1039 | 1.0563 | 10.3576 | | 3.0.0.a2 | 0.0870 | 0.9430 | 8.8815 | | 3.0.0.a1 | 0.1399 | 1.4136 | 14.1829 | From 2eaa4698e917ea0a35bc7cccc936d4055ccf9e02 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 20 Apr 2026 17:06:29 +0200 Subject: [PATCH 05/13] feat: allow proper covariant typing. --- Justfile | 2 + docs/providers/collections.md | 8 +-- docs/providers/context-resources.md | 4 +- pyproject.toml | 2 +- tests/experimental/test_container_2.py | 16 ++++++ tests/providers/test_collections.py | 6 ++- tests/providers/test_context_resources.py | 8 ++- tests/test_meta.py | 7 +++ that_depends/container.py | 40 ++++++++++++++- that_depends/entities/resource_context.py | 48 +++++++++++++++--- that_depends/experimental/providers.py | 40 +++++++-------- that_depends/integrations/faststream.py | 17 ++++--- that_depends/meta.py | 37 ++------------ that_depends/providers/base.py | 45 +++++++++-------- that_depends/providers/collection.py | 43 ++++++++-------- that_depends/providers/context_resources.py | 54 ++++++++++----------- that_depends/providers/factories.py | 24 +++------ that_depends/providers/local_singleton.py | 22 ++++----- that_depends/providers/singleton.py | 20 +++----- 19 files changed, 259 insertions(+), 184 deletions(-) diff --git a/Justfile b/Justfile index 7bab81d..d28ab97 100644 --- a/Justfile +++ b/Justfile @@ -9,11 +9,13 @@ lint: uv run ruff format uv run ruff check --fix uv run mypy . + uv run pyrefly check lint-ci: uv run ruff format --check uv run ruff check --no-fix uv run mypy . + uv run pyrefly check test *args: uv run --no-sync pytest {{ args }} diff --git a/docs/providers/collections.md b/docs/providers/collections.md index 34294ba..8615441 100644 --- a/docs/providers/collections.md +++ b/docs/providers/collections.md @@ -3,7 +3,7 @@ There are several collection providers: `List` and `Dict` ## List - List provider contains other providers. -- Resolves into list of dependencies. +- Resolves into an immutable sequence of dependencies. ```python import random @@ -16,12 +16,12 @@ class DIContainer(BaseContainer): DIContainer.numbers_sequence.resolve_sync() -# [0.3035656170071561, 0.8280498192037787] +# (0.3035656170071561, 0.8280498192037787) ``` ## Dict - Dict provider is a collection of named providers. -- Resolves into dict of dependencies. +- Resolves into a read-only mapping of dependencies. ```python import random @@ -34,5 +34,5 @@ class DIContainer(BaseContainer): DIContainer.numbers_map.resolve_sync() -# {'key1': 0.6851384528299208, 'key2': 0.41044920948045294} +# mappingproxy({'key1': 0.6851384528299208, 'key2': 0.41044920948045294}) ``` diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index 8aaa327..87e00c9 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -8,7 +8,7 @@ To interact with both types of contexts, there are two separate interfaces: 1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers. -2. Directly manage a `ContextResource` context by using the `SupportsContext` interface, which both containers +2. Directly manage a `ContextResource` context by using the `SupportsContext` protocol, which both containers and `ContextResource` providers implement. --- @@ -185,7 +185,7 @@ async with container_context(MyContainer.async_resource): ... ``` -It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` interface described +It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` protocol described [here](#quick-reference). ### Context Hierarchy diff --git a/pyproject.toml b/pyproject.toml index 2b9359e..25857d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dev = [ "mkdocs>=1.6.1", "pytest-randomly", "mkdocs-llmstxt>=0.4.0", - "pydantic<2.12.4" # causes tests to fail because of: TypeError: _eval_type() got an unexpected keyword argument 'prefer_fwd_module' + "pyrefly>=0.61.1", ] [build-system] diff --git a/tests/experimental/test_container_2.py b/tests/experimental/test_container_2.py index f5a9ec4..30b1156 100644 --- a/tests/experimental/test_container_2.py +++ b/tests/experimental/test_container_2.py @@ -22,6 +22,12 @@ def __hash__(self) -> int: return 0 # pragma: nocover +class _BrokenContextObject(providers.Object[int]): + def get_scope(self) -> ContextScopes | None: + msg = "_missing_scope" + raise AttributeError(msg) + + async def _async_creator() -> AsyncIterator[float]: yield random.random() @@ -30,6 +36,9 @@ def _sync_creator() -> Iterator[_RandomWrapper]: yield _RandomWrapper() +broken_context_provider = _BrokenContextObject(1) + + class Container2(BaseContainer): """Test Container 2.""" @@ -139,6 +148,13 @@ async def test_lazy_provider_not_implemented() -> None: await lazy_provider.tear_down() +def test_lazy_provider_not_implemented_when_context_method_raises_attribute_error() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.broken_context_provider") + + with pytest.raises(NotImplementedError): + lazy_provider.get_scope() + + def test_lazy_provider_attr_getter() -> None: lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") with lazy_provider.context_sync(force=True): diff --git a/tests/providers/test_collections.py b/tests/providers/test_collections.py index 14a58ae..5d028a3 100644 --- a/tests/providers/test_collections.py +++ b/tests/providers/test_collections.py @@ -28,7 +28,9 @@ async def test_list_provider() -> None: sync_resource = await DIContainer.sync_resource() async_resource = await DIContainer.async_resource() - assert sequence == [sync_resource, async_resource] + assert sequence == (sync_resource, async_resource) + with pytest.raises(TypeError): + typing.cast(typing.Any, sequence)[0] = sync_resource def test_list_failed_sync_resolve() -> None: @@ -48,6 +50,8 @@ async def test_dict_provider() -> None: assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource} assert mapping == DIContainer.mapping.resolve_sync() + with pytest.raises(TypeError): + typing.cast(typing.Any, mapping)["sync_resource"] = sync_resource @pytest.mark.parametrize("provider", [DIContainer.sequence, DIContainer.mapping]) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 9420097..73832ad 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -186,6 +186,7 @@ async def test_early_exit_of_container_context() -> None: async def test_resource_context_early_teardown() -> None: context: ResourceContext[str] = ResourceContext(is_async=True) + assert context.is_async is True assert context.context_stack is None context.tear_down_sync() assert context.context_stack is None @@ -193,7 +194,7 @@ async def test_resource_context_early_teardown() -> None: async def test_teardown_sync_container_context_with_async_resource() -> None: resource_context: ResourceContext[typing.Any] = ResourceContext(is_async=True) - resource_context.context_stack = AsyncExitStack() + resource_context.set_context_state(context_stack=AsyncExitStack()) message = "Cannot tear down async context in sync mode" with pytest.raises(RuntimeError, match=message): resource_context.tear_down_sync() @@ -699,6 +700,11 @@ class _Container(BaseContainer): ... assert _Container.get_scope() is ContextScopes.ANY + class _UnsetScopeContainer(BaseContainer): + default_scope = None + + assert _UnsetScopeContainer.get_scope() is ContextScopes.ANY + class _ScopedContainer(BaseContainer): default_scope = ContextScopes.INJECT diff --git a/tests/test_meta.py b/tests/test_meta.py index b4efc14..6f1a935 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -9,6 +9,13 @@ class _Test(metaclass=BaseContainerMeta): assert _Test.get_scope() == ContextScopes.ANY +def test_base_container_meta_uses_explicit_default_scope() -> None: + class _Test(metaclass=BaseContainerMeta): + default_scope = ContextScopes.INJECT + + assert _Test.get_scope() == ContextScopes.INJECT + + def test_base_container_meta_has_correct_default_name() -> None: class _Test(metaclass=BaseContainerMeta): pass diff --git a/that_depends/container.py b/that_depends/container.py index f42d08a..d23afe4 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,13 +1,13 @@ import inspect import typing -from contextlib import contextmanager +from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from typing import overload from typing_extensions import override from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, AsyncSingleton, Resource, Singleton -from that_depends.providers.context_resources import ContextScope, ContextScopes +from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes if typing.TYPE_CHECKING: @@ -26,6 +26,42 @@ class BaseContainer(metaclass=BaseContainerMeta): containers: list[type["BaseContainer"]] default_scope: ContextScope | None = ContextScopes.ANY + @classmethod + def get_scope(cls) -> ContextScope | None: + """Return the default scope used by the container.""" + if cls.default_scope is not None: + return cls.default_scope + return ContextScopes.ANY + + @classmethod + @asynccontextmanager + async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]: + """Enter async contexts for all connected containers and context resources.""" + async with AsyncExitStack() as stack: + for container in cls.get_containers(): + await stack.enter_async_context(container.context_async(force=force)) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource): + await stack.enter_async_context(provider.context_async(force=force)) + yield + + @classmethod + @contextmanager + def context_sync(cls, force: bool = False) -> typing.Iterator[None]: + """Enter sync contexts for all connected containers and sync context resources.""" + with ExitStack() as stack: + for container in cls.get_containers(): + stack.enter_context(container.context_sync(force=force)) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001 + stack.enter_context(provider.context_sync(force=force)) + yield + + @classmethod + def supports_context_sync(cls) -> bool: + """Indicate that container classes support sync context management.""" + return True + @classmethod @overload def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... diff --git a/that_depends/entities/resource_context.py b/that_depends/entities/resource_context.py index 8e10b6e..9e0e10a 100644 --- a/that_depends/entities/resource_context.py +++ b/that_depends/entities/resource_context.py @@ -15,7 +15,7 @@ class ResourceContext(SupportsTeardown, typing.Generic[T_co]): """Class to manage a resources' context.""" - __slots__ = "asyncio_lock", "context_stack", "instance", "is_async", "threading_lock" + __slots__ = "_context_stack", "_instance", "_is_async", "asyncio_lock", "threading_lock" def __init__(self, is_async: bool) -> None: """Create a new ResourceContext instance. @@ -26,11 +26,45 @@ def __init__(self, is_async: bool) -> None: For example within a ``async with container_context(Container): ...`` statement. """ - self.instance: T_co | None = None + self._instance: T_co | None = None self.asyncio_lock: typing.Final = asyncio.Lock() self.threading_lock: typing.Final = threading.Lock() - self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None - self.is_async = is_async + self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None + self._is_async = is_async + + @property + def instance(self) -> T_co | None: + """Return the currently cached resource instance, if any.""" + return self._instance + + @property + def context_stack(self) -> contextlib.AsyncExitStack | contextlib.ExitStack | None: + """Return the active context stack, if any.""" + return self._context_stack + + @property + def is_async(self) -> bool: + """Indicate whether this context was created in async mode.""" + return self._is_async + + def set_context_state( + self, + *, + instance: T_co | None = None, + context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None, + ) -> None: + """Set the context state of the resource. + + Args: + instance: instance. + context_stack: stack. + + Returns: + None + + """ + self._instance = instance + self._context_stack = context_stack @staticmethod def is_context_stack_async( @@ -56,8 +90,7 @@ async def tear_down(self, propagate: bool = True) -> None: await self.context_stack.aclose() elif self.is_context_stack_sync(self.context_stack): self.context_stack.close() - self.context_stack = None - self.instance = None + self.set_context_state(instance=None, context_stack=None) @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: @@ -67,8 +100,7 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> if self.is_context_stack_sync(self.context_stack): self.context_stack.close() - self.context_stack = None - self.instance = None + self.set_context_state(instance=None, context_stack=None) elif self.is_context_stack_async(self.context_stack): msg = "Cannot tear down async context in sync mode" if raise_on_async: diff --git a/that_depends/experimental/providers.py b/that_depends/experimental/providers.py index c79becf..6a5b1ec 100644 --- a/that_depends/experimental/providers.py +++ b/that_depends/experimental/providers.py @@ -7,7 +7,7 @@ from that_depends import ContextScope from that_depends.providers import AbstractProvider -from that_depends.providers.context_resources import CT, SupportsContext +from that_depends.providers.context_resources import CT_co, SupportsContext from that_depends.providers.mixin import SupportsTeardown @@ -55,32 +55,19 @@ def __init__( @typing_extensions.override def get_scope(self) -> ContextScope | None: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.get_scope() - msg = "Underlying provider does not support context scopes" - raise NotImplementedError(msg) + return typing.cast(ContextScope | None, self._call_context_method("get_scope")) @typing_extensions.override - def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.context_async(force) - msg = "Underlying provider does not support context management" - raise NotImplementedError(msg) + def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]: + return typing.cast(typing.AsyncContextManager[CT_co], self._call_context_method("context_async", force)) @typing_extensions.override - def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: - provider = self._get_provider() - if isinstance(provider, SupportsContext): - return provider.context_sync(force) - msg = "Underlying provider does not support context management" - raise NotImplementedError(msg) + def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]: + return typing.cast(typing.ContextManager[CT_co], self._call_context_method("context_sync", force)) @typing_extensions.override def supports_context_sync(self) -> bool: - provider = self._get_provider() - return isinstance(provider, SupportsContext) and provider.supports_context_sync() + return typing.cast(bool, self._call_context_method("supports_context_sync")) @typing_extensions.override async def tear_down(self, propagate: bool = True) -> None: @@ -139,6 +126,19 @@ def _get_provider(self) -> AbstractProvider[Any]: self._provider = cast(AbstractProvider[Any], provider) return self._provider + def _call_context_method(self, method_name: str, *args: object) -> object: + provider = self._get_provider() + try: + method = getattr(type(provider), method_name) + except AttributeError as e: + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) from e + try: + return method(provider, *args) + except AttributeError as e: + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) from e + @typing_extensions.override async def resolve(self) -> Any: provider = self._get_provider() diff --git a/that_depends/integrations/faststream.py b/that_depends/integrations/faststream.py index 67d9fc6..f576518 100644 --- a/that_depends/integrations/faststream.py +++ b/that_depends/integrations/faststream.py @@ -22,7 +22,7 @@ class DIContextMiddleware(BaseMiddleware): def __init__( self, - *context_items: SupportsContext[Any], + *context_items: SupportsContext[typing.Any], msg: AnyMsg | None = None, context: Optional["ContextRepo"] = None, global_context: dict[str, Any] | Unset = UNSET, @@ -31,7 +31,7 @@ def __init__( """Initialize the container context middleware. Args: - *context_items (SupportsContext[Any]): Context items to initialize. + *context_items: Context-capable providers or container classes to initialize. msg (Any): Message object. context (ContextRepo): Context repository. global_context (dict[str, Any] | Unset): Global context to initialize the container. @@ -40,7 +40,7 @@ def __init__( """ super().__init__(msg, context=context) # type: ignore[arg-type] self._context: container_context | None = None - self._context_items = set(context_items) + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context = global_context self._scope = scope @@ -93,21 +93,26 @@ class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] def __init__( self, - *context_items: SupportsContext[Any], + *context_items: SupportsContext[typing.Any], + msg: object | None = None, + context: object | None = None, global_context: dict[str, Any] | Unset = UNSET, scope: ContextScope | Unset = UNSET, ) -> None: """Initialize the container context middleware. Args: - *context_items (SupportsContext[Any]): Context items to initialize. + *context_items: Context-capable providers or container classes to initialize. + msg: Message object passed by faststream. + context: Context repository passed by faststream. global_context (dict[str, Any] | Unset): Global context to initialize the container. scope (ContextScope | Unset): Context scope to initialize the container. """ + del msg, context super().__init__() # type: ignore[call-arg] self._context: container_context | None = None - self._context_items = set(context_items) + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context = global_context self._scope = scope diff --git a/that_depends/meta.py b/that_depends/meta.py index 21507a4..164722e 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -2,7 +2,6 @@ import typing import warnings from collections.abc import MutableMapping -from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from threading import Lock from typing import TYPE_CHECKING @@ -11,7 +10,7 @@ import that_depends from that_depends.exceptions import TypeNotBoundError from that_depends.providers import AbstractProvider, Resource -from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes, SupportsContext +from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes from that_depends.providers.mixin import SupportsTeardown @@ -43,41 +42,15 @@ def __setitem__(self, key: str, value: typing.Any) -> None: super().__setitem__(key, value) -class BaseContainerMeta(abc.ABCMeta, SupportsContext[None]): +class BaseContainerMeta(abc.ABCMeta): """Metaclass for BaseContainer.""" - @override def get_scope(cls) -> ContextScope | None: + """Return the default scope used by the container.""" if scope := getattr(cls, "default_scope", None): return typing.cast(ContextScope | None, scope) return ContextScopes.ANY - @asynccontextmanager - @override - async def context_async(cls, force: bool = False) -> typing.AsyncIterator[None]: - async with AsyncExitStack() as stack: - for container in cls.get_containers(): - await stack.enter_async_context(container.context_async(force=force)) - for provider in cls.get_providers().values(): - if isinstance(provider, ContextResource): - await stack.enter_async_context(provider.context_async(force=force)) - yield - - @contextmanager - @override - def context_sync(cls, force: bool = False) -> typing.Iterator[None]: - with ExitStack() as stack: - for container in cls.get_containers(): - stack.enter_context(container.context_sync(force=force)) - for provider in cls.get_providers().values(): - if isinstance(provider, ContextResource) and not provider._is_async: # noqa: SLF001 - stack.enter_context(provider.context_sync(force=force)) - yield - - @override - def supports_context_sync(cls) -> bool: - return True - _instances: typing.ClassVar[dict[str, type["BaseContainer"]]] = {} _MUTABLE_ATTRS = ( @@ -114,10 +87,6 @@ def name(cls) -> str: def __prepare__(cls, name: str, bases: tuple[type, ...], /, **kwds: typing.Any) -> MutableMapping[str, object]: return _ContainerMetaDict() - @override - def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: - return super().__new__(cls, name, bases, namespace) - @classmethod def get_instances(cls) -> dict[str, type["BaseContainer"]]: """Get all instances that inherit from BaseContainer.""" diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 9c30bdf..5afc390 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -18,7 +18,10 @@ P = typing.ParamSpec("P") ResourceCreatorType: typing.TypeAlias = typing.Callable[ P, - typing.Iterator[T_co] | typing.AsyncIterator[T_co] | typing.ContextManager[T_co] | typing.AsyncContextManager[T_co], + typing.Iterator[T_co] + | typing.AsyncIterator[T_co] + | contextlib.AbstractContextManager[T_co] + | contextlib.AbstractAsyncContextManager[T_co], ] @@ -314,10 +317,10 @@ def __init__( elif inspect.isgeneratorfunction(creator): self._is_async = False self._creator = contextlib.contextmanager(creator) - elif isinstance(creator, type) and issubclass(creator, typing.AsyncContextManager): + elif isinstance(creator, type) and hasattr(creator, "__aenter__") and hasattr(creator, "__aexit__"): self._is_async = True self._creator = creator - elif isinstance(creator, type) and issubclass(creator, typing.ContextManager): + elif isinstance(creator, type) and hasattr(creator, "__enter__") and hasattr(creator, "__exit__"): self._is_async = False self._creator = creator else: @@ -351,23 +354,26 @@ async def resolve(self) -> T_co: self._register_arguments() - cm: typing.ContextManager[T_co] | typing.AsyncContextManager[T_co] = self._creator( + cm: contextlib.AbstractContextManager[T_co] | contextlib.AbstractAsyncContextManager[T_co] = self._creator( *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) + stack: contextlib.AsyncExitStack | contextlib.ExitStack - if isinstance(cm, typing.AsyncContextManager): - context.context_stack = contextlib.AsyncExitStack() - context.instance = await context.context_stack.enter_async_context(cm) + if isinstance(cm, contextlib.AbstractAsyncContextManager): + stack = contextlib.AsyncExitStack() + instance = await stack.enter_async_context(cm) - elif isinstance(cm, typing.ContextManager): - context.context_stack = contextlib.ExitStack() - context.instance = context.context_stack.enter_context(cm) + elif isinstance(cm, contextlib.AbstractContextManager): + stack = contextlib.ExitStack() + instance = stack.enter_context(cm) else: # pragma: no cover typing.assert_never(cm) - return context.instance + context.set_context_state(instance=instance, context_stack=stack) + + return instance @override def resolve_sync(self) -> T_co: @@ -387,14 +393,15 @@ def resolve_sync(self) -> T_co: self._register_arguments() - cm = self._creator( + cm: contextlib.AbstractContextManager[T_co] = self._creator( *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) - context.context_stack = contextlib.ExitStack() - context.instance = context.context_stack.enter_context(cm) + stack = contextlib.ExitStack() + instance = stack.enter_context(cm) + context.set_context_state(instance=instance, context_stack=stack) - return context.instance + return instance def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401 @@ -433,11 +440,11 @@ def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None: self._attrs = [attr_name] @override - def __getattr__(self, attr: str) -> "AttrGetter[T_co]": - if attr.startswith("_"): - msg = f"'{type(self)}' object has no attribute '{attr}'" + def __getattr__(self, attr_name: str) -> "AttrGetter[T_co]": + if attr_name.startswith("_"): + msg = f"'{type(self)}' object has no attribute '{attr_name}'" raise AttributeError(msg) - self._attrs.append(attr) + self._attrs.append(attr_name) return self @override diff --git a/that_depends/providers/collection.py b/that_depends/providers/collection.py index 28533d1..08e0a4c 100644 --- a/that_depends/providers/collection.py +++ b/that_depends/providers/collection.py @@ -1,4 +1,6 @@ import typing +from collections.abc import Mapping, Sequence +from types import MappingProxyType from typing_extensions import override @@ -8,10 +10,11 @@ T_co = typing.TypeVar("T_co", covariant=True) -class List(AbstractProvider[list[T_co]]): - """Provides multiple resources as a list. +class List(AbstractProvider[Sequence[T_co]]): + """Provides multiple resources as a read-only sequence. - The `List` provider resolves multiple dependencies into a list. + The `List` provider resolves multiple dependencies into an immutable tuple + exposed through the `Sequence` interface. Example: ```python @@ -24,12 +27,12 @@ class List(AbstractProvider[list[T_co]]): # Synchronous resolution resolved_list = list_provider.resolve_sync() - print(resolved_list) # Output: [1, 2] + print(tuple(resolved_list)) # Output: (1, 2) # Asynchronous resolution import asyncio resolved_list_async = asyncio.run(list_provider.resolve()) - print(resolved_list_async) # Output: [1, 2] + print(tuple(resolved_list_async)) # Output: (1, 2) ``` """ @@ -53,22 +56,24 @@ def __getattr__(self, attr_name: str) -> typing.Any: raise AttributeError(msg) @override - async def resolve(self) -> list[T_co]: - return [await x.resolve() for x in self._providers] + async def resolve(self) -> Sequence[T_co]: + resolved = [await x.resolve() for x in self._providers] + return tuple(resolved) @override - def resolve_sync(self) -> list[T_co]: - return [x.resolve_sync() for x in self._providers] + def resolve_sync(self) -> Sequence[T_co]: + resolved = [x.resolve_sync() for x in self._providers] + return tuple(resolved) @override - async def __call__(self) -> list[T_co]: + async def __call__(self) -> Sequence[T_co]: return await self.resolve() -class Dict(AbstractProvider[dict[str, T_co]]): - """Provides multiple resources as a dictionary. +class Dict(AbstractProvider[Mapping[str, T_co]]): + """Provides multiple resources as a read-only mapping. - The `Dict` provider resolves multiple named dependencies into a dictionary. + The `Dict` provider resolves multiple named dependencies into a read-only mapping. Example: ```python @@ -81,12 +86,12 @@ class Dict(AbstractProvider[dict[str, T_co]]): # Synchronous resolution resolved_dict = dict_provider.resolve_sync() - print(resolved_dict) # Output: {"key1": 1, "key2": 2} + print(dict(resolved_dict)) # Output: {"key1": 1, "key2": 2} # Asynchronous resolution import asyncio resolved_dict_async = asyncio.run(dict_provider.resolve()) - print(resolved_dict_async) # Output: {"key1": 1, "key2": 2} + print(dict(resolved_dict_async)) # Output: {"key1": 1, "key2": 2} ``` """ @@ -110,9 +115,9 @@ def __getattr__(self, attr_name: str) -> typing.Any: raise AttributeError(msg) @override - async def resolve(self) -> dict[str, T_co]: - return {key: await provider.resolve() for key, provider in self._providers.items()} + async def resolve(self) -> Mapping[str, T_co]: + return MappingProxyType({key: await provider.resolve() for key, provider in self._providers.items()}) @override - def resolve_sync(self) -> dict[str, T_co]: - return {key: provider.resolve_sync() for key, provider in self._providers.items()} + def resolve_sync(self) -> Mapping[str, T_co]: + return MappingProxyType({key: provider.resolve_sync() for key, provider in self._providers.items()}) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 4dab61a..7495390 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -3,7 +3,6 @@ import inspect import logging import typing -from abc import abstractmethod from collections.abc import Iterable from contextlib import AbstractAsyncContextManager, AbstractContextManager from contextvars import ContextVar, Token @@ -11,7 +10,7 @@ from types import TracebackType from typing import Final, overload -from typing_extensions import TypeIs, override +from typing_extensions import Protocol, TypeIs, override, runtime_checkable from that_depends.entities.resource_context import ResourceContext from that_depends.providers.base import AbstractResource @@ -102,22 +101,21 @@ def _enter_named_scope(scope: ContextScope) -> typing.Iterator[ContextScope]: T = typing.TypeVar("T") -CT = typing.TypeVar("CT") +CT_co = typing.TypeVar("CT_co", covariant=True) -class SupportsContext(typing.Generic[CT]): +@runtime_checkable +class SupportsContext(Protocol[CT_co]): """Interface for resources that support context initialization. This interface defines methods to create synchronous and asynchronous context managers, as well as a function decorator for context initialization. """ - @abstractmethod def get_scope(self) -> ContextScope | None: """Return the scope of the resource.""" - @abstractmethod - def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: + def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT_co]: """Create an async context manager for this resource. Args: @@ -133,9 +131,9 @@ def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: ``` """ + ... - @abstractmethod - def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: + def context_sync(self, force: bool = False) -> typing.ContextManager[CT_co]: """Create a sync context manager for this resource. Args: @@ -151,8 +149,8 @@ def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: ``` """ + ... - @abstractmethod def supports_context_sync(self) -> bool: """Check whether the resource supports sync context. @@ -160,6 +158,10 @@ def supports_context_sync(self) -> bool: bool: True if sync context is supported, False otherwise. """ + ... + + +BaseContainerType: typing.TypeAlias = type["BaseContainer"] def _get_container_context() -> dict[str, typing.Any] | None: @@ -274,7 +276,7 @@ def __init__( """ super().__init__(creator, *args, **kwargs) - self._from_creator = creator + self._from_creator: typing.Callable[..., typing.Iterator[T_co] | typing.AsyncIterator[T_co]] = creator self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") self._token: Token[ResourceContext[T_co]] | None = None self._async_lock: Final = asyncio.Lock() @@ -334,7 +336,7 @@ def with_config(self, scope: ContextScope | None, strict_scope: bool = False) -> if strict_scope and scope == ContextScopes.ANY: msg = f"Cannot set strict_scope with scope {scope}." raise ValueError(msg) - r = ContextResource(self._from_creator, *self._args, **self._kwargs) # type: ignore[arg-type] + r = ContextResource(self._from_creator, *self._args, **self._kwargs) r._scope = scope r._strict_scope = strict_scope @@ -423,9 +425,6 @@ def _fetch_context(self) -> ResourceContext[T_co]: raise RuntimeError(msg) from e -ContainerType = typing.TypeVar("ContainerType", bound="type[BaseContainer]") - - class container_context(AbstractContextManager[ContextType], AbstractAsyncContextManager[ContextType]): # noqa: N801 """Initialize contexts for the provided containers or resources. @@ -434,15 +433,13 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex """ __slots__ = ( - "_containers", "_context_items", - "_context_providers", "_context_stack", "_context_token", + "_entered_context_items", "_global_context", "_initial_context", "_preserve_global_context", - "_providers", "_reset_resource_context", "_scope", "_scope_token", @@ -458,7 +455,7 @@ def __init__( """Initialize a new container context. Args: - *context_items (SupportsContext[Any]): Context items to initialize a new context for. + *context_items: Context-capable providers or container classes to initialize. global_context (dict[str, Any] | None): A dictionary representing the global context. preserve_global_context (bool): If True, merges the existing global context with the new one. scope (ContextScope | None): The named scope that should be initialized. @@ -481,7 +478,7 @@ def __init__( self._global_context = global_context self._context_token: Token[ContextType] | None = None self._context_items: typing.Final[set[SupportsContext[typing.Any]]] = set(context_items) - self._context_providers: set[ContextResource[typing.Any]] = set() + self._entered_context_items: set[SupportsContext[typing.Any]] = set() self._reset_resource_context: typing.Final[bool] = bool(scope) self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None self._scope_token: Token[ContextScope | None] | None = None @@ -508,25 +505,27 @@ def _resolve_initial_conditions(self) -> None: if isinstance(item, type) and issubclass(item, BaseContainer): self._add_providers_from_containers([item], self._scope) - elif isinstance(item, ContextResource): - self._context_providers.add(item) + else: + self._entered_context_items.add(item) def _add_providers_from_containers( - self, containers: Iterable[ContainerType], scope: ContextScope | None = ContextScopes.ANY + self, + containers: Iterable[BaseContainerType], + scope: ContextScope | None = ContextScopes.ANY, ) -> None: for container in containers: for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): provider_scope = container_provider.get_scope() if provider_scope in (scope, ContextScopes.ANY): - self._context_providers.add(container_provider) + self._entered_context_items.add(container_provider) @override def __enter__(self) -> ContextType: self._resolve_initial_conditions() self._context_stack = contextlib.ExitStack() self._scope_token = _set_current_scope(self._scope) - for item in self._context_providers: + for item in self._entered_context_items: if item.supports_context_sync(): self._context_stack.enter_context(item.context_sync()) return self._enter_globals() @@ -536,7 +535,7 @@ async def __aenter__(self) -> ContextType: self._resolve_initial_conditions() self._context_stack = contextlib.AsyncExitStack() self._scope_token = _set_current_scope(self._scope) - for item in self._context_providers: + for item in self._entered_context_items: await self._context_stack.enter_async_context(item.context_async()) return self._enter_globals() @@ -667,8 +666,7 @@ def __init__( Args: app (ASGIApp): The ASGI application to wrap. - *context_items (SupportsContext[Any]): A collection of containers and providers that - need context initialization prior to a request. + *context_items: Containers and providers that need context initialization prior to a request. global_context (dict[str, Any] | None): A global context dictionary to set before requests. scope (ContextScope | None): The scope in which the context should be initialized. diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index 6242ea5..6203b23 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -94,7 +94,7 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co]] = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs self._register_arguments() @@ -105,12 +105,8 @@ async def resolve(self) -> T_co: return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) @override @@ -119,12 +115,8 @@ def resolve_sync(self) -> T_co: return typing.cast(T_co, self._override) return self._factory( - *[ # type: ignore[arg-type] - x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) @@ -176,7 +168,7 @@ def __init__( """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co | typing.Awaitable[T_co]]] = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs self._register_arguments() @@ -190,8 +182,8 @@ async def resolve(self) -> T_co: kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()} result = self._factory( - *args, # type:ignore[arg-type] - **kwargs, # type:ignore[arg-type] + *args, + **kwargs, ) if inspect.isawaitable(result): diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index dd8f1ef..1a6fdd5 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -50,7 +50,7 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co]] = factory self._thread_local = threading.local() self._asyncio_lock = asyncio.Lock() self._args: typing.Final = args @@ -83,13 +83,12 @@ async def resolve(self) -> T_co: self._register_arguments() - self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + instance = self._factory( + *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) - return self._instance + self._instance = instance + return instance @override def resolve_sync(self) -> T_co: @@ -101,11 +100,12 @@ def resolve_sync(self) -> T_co: self._register_arguments() - self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + instance = self._factory( + *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) - return self._instance + self._instance = instance + return instance @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index daa470d..012f5e6 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -46,7 +46,7 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() - self._factory: typing.Final = factory + self._factory: typing.Final[typing.Callable[..., T_co]] = factory self._instance: T_co | None = None self._asyncio_lock: typing.Final = asyncio.Lock() self._threading_lock: typing.Final = threading.Lock() @@ -73,10 +73,8 @@ async def resolve(self) -> T_co: return self._instance self._register_arguments() self._instance = self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) return self._instance @@ -92,8 +90,8 @@ def resolve_sync(self) -> T_co: return self._instance self._register_arguments() self._instance = self._factory( - *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) return self._instance @@ -159,7 +157,7 @@ def __init__( """ super().__init__() - self._factory: typing.Final[typing.Callable[P, typing.Awaitable[T_co]]] = factory + self._factory: typing.Final[typing.Callable[..., typing.Awaitable[T_co]]] = factory self._instance: T_co | None = None self._asyncio_lock: typing.Final = asyncio.Lock() self._args: typing.Final = args @@ -186,10 +184,8 @@ async def resolve(self) -> T_co: self._register_arguments() self._instance = await self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] - k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) return self._instance From 0bfed716770232f6e1ae0b88e6be0a35bb54794c Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 21 Apr 2026 12:50:50 +0200 Subject: [PATCH 06/13] fix: fix faststream integration. --- Justfile | 4 ++-- pyproject.toml | 1 + that_depends/integrations/faststream.py | 13 ++++--------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/Justfile b/Justfile index d28ab97..073c24d 100644 --- a/Justfile +++ b/Justfile @@ -8,13 +8,13 @@ install: lint: uv run ruff format uv run ruff check --fix - uv run mypy . + uv run mypy . --disable-error-code=unused-ignore uv run pyrefly check lint-ci: uv run ruff format --check uv run ruff check --no-fix - uv run mypy . + uv run mypy . --disable-error-code=unused-ignore uv run pyrefly check test *args: diff --git a/pyproject.toml b/pyproject.toml index 25857d5..077ab98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ module-root = "" python_version = "3.10" strict = true + [tool.ruff] fix = true unsafe-fixes = true diff --git a/that_depends/integrations/faststream.py b/that_depends/integrations/faststream.py index f576518..2120f5b 100644 --- a/that_depends/integrations/faststream.py +++ b/that_depends/integrations/faststream.py @@ -31,7 +31,7 @@ def __init__( """Initialize the container context middleware. Args: - *context_items: Context-capable providers or container classes to initialize. + *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize. msg (Any): Message object. context (ContextRepo): Context repository. global_context (dict[str, Any] | Unset): Global context to initialize the container. @@ -79,8 +79,8 @@ def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # return DIContextMiddleware( *self._context_items, - msg=msg, - context=context, + msg=msg, # type:ignore[unexpected-keyword] + context=context, # type:ignore[unexpected-keyword] scope=self._scope, global_context=self._global_context, ) @@ -94,22 +94,17 @@ class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] def __init__( self, *context_items: SupportsContext[typing.Any], - msg: object | None = None, - context: object | None = None, global_context: dict[str, Any] | Unset = UNSET, scope: ContextScope | Unset = UNSET, ) -> None: """Initialize the container context middleware. Args: - *context_items: Context-capable providers or container classes to initialize. - msg: Message object passed by faststream. - context: Context repository passed by faststream. + *context_items (SupportsContext[Any]): Context-capable providers or container classes to initialize. global_context (dict[str, Any] | Unset): Global context to initialize the container. scope (ContextScope | Unset): Context scope to initialize the container. """ - del msg, context super().__init__() # type: ignore[call-arg] self._context: container_context | None = None self._context_items: set[SupportsContext[typing.Any]] = set(context_items) From e2e05395869566b51f0a1a8afbc9b18017d0233d Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 21 Apr 2026 13:11:17 +0200 Subject: [PATCH 07/13] fix: change from typing to typing extensions. --- Justfile | 4 ++-- that_depends/providers/base.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Justfile b/Justfile index 073c24d..1c25f06 100644 --- a/Justfile +++ b/Justfile @@ -9,13 +9,13 @@ lint: uv run ruff format uv run ruff check --fix uv run mypy . --disable-error-code=unused-ignore - uv run pyrefly check + uv run pyrefly check --no-progress-bar lint-ci: uv run ruff format --check uv run ruff check --no-fix uv run mypy . --disable-error-code=unused-ignore - uv run pyrefly check + uv run pyrefly check --no-progress-bar test *args: uv run --no-sync pytest {{ args }} diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index da03e05..8713001 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -504,7 +504,7 @@ async def resolve(self) -> T_co: context.instance = context.context_stack.enter_context(cm) else: # pragma: no cover - typing.assert_never(cm) + typing_extensions.assert_never(cm) return context.instance From c428bfa5408304ce9434dcd43a5a8850e8cb5d5a Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 24 Apr 2026 14:30:54 +0200 Subject: [PATCH 08/13] test: add tests for coverage. --- tests/providers/test_attr_getter.py | 1 + tests/providers/test_base.py | 46 ++++++++- tests/providers/test_context_resources.py | 27 +++++- tests/providers/test_local_singleton.py | 17 ++++ tests/providers/test_singleton.py | 10 ++ tests/test_injection.py | 111 +++++++++++++++++++++- 6 files changed, 207 insertions(+), 5 deletions(-) diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index b1292c6..c8ea886 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -171,6 +171,7 @@ class _Container(BaseContainer): attr_getter = _Container.child.v + attr_getter._register_arguments() attr_getter._register_arguments() assert attr_getter.resolve_sync() == _Container.parent.resolve_sync() diff --git a/tests/providers/test_base.py b/tests/providers/test_base.py index 1c8edbb..5b2084c 100644 --- a/tests/providers/test_base.py +++ b/tests/providers/test_base.py @@ -6,7 +6,7 @@ from typing_extensions import override from that_depends import BaseContainer -from that_depends.providers.base import AbstractProvider +from that_depends.providers.base import AbstractProvider, _resolve_arguments, _resolve_arguments_sync from that_depends.providers.local_singleton import ThreadLocalSingleton from that_depends.providers.mixin import SupportsTeardown from that_depends.providers.resources import Resource @@ -43,6 +43,18 @@ def resolve_sync(self) -> int: return self._instance # pragma: no cover +async def test_resolve_arguments_with_multiple_values() -> None: + provider = DummyProvider() + + assert await _resolve_arguments((provider, "value"), (True, False)) == [1, "value"] + + +def test_resolve_arguments_sync_with_multiple_values() -> None: + provider = DummyProvider() + + assert _resolve_arguments_sync((provider, "value"), (True, False)) == [1, "value"] + + def test_add_child_provider() -> None: provider_a = DummyProvider() provider_b = DummyProvider() @@ -75,6 +87,28 @@ def test_register_with_mixed_items() -> None: assert parent in child_1._children, "Expected child_1._children to contain parent" +def test_invalidate_scope_init_order_handles_duplicate_descendants() -> None: + root = DummyProvider() + left = DummyProvider() + right = DummyProvider() + shared = DummyProvider() + + root.add_child_provider(left) + root.add_child_provider(right) + left.add_child_provider(shared) + right.add_child_provider(shared) + + for provider in (root, left, right, shared): + provider._scope_context_init_order = () + provider._scope_init_order = () + + root._invalidate_scope_init_order() + + for provider in (root, left, right, shared): + assert provider._scope_context_init_order is None + assert provider._scope_init_order is None + + def test_sync_tear_down_propagation() -> None: parent = DummyProvider() child_1 = DummyProvider() @@ -168,6 +202,16 @@ def test_thread_local_singleton_registration_and_deregistration(dummy_singleton: assert thread_local not in dummy_singleton._children, "ThreadLocalSingleton should be deregistered after teardown." +def test_get_scope_init_order_is_cached_and_includes_parents(dummy_singleton: Singleton[int]) -> None: + singleton = Singleton(lambda value: value + 1, dummy_singleton.cast) + + first = singleton._get_scope_init_order() + second = singleton._get_scope_init_order() + + assert first == (dummy_singleton, singleton) + assert second is first + + def test_resource_registration_and_deregistration(dummy_singleton: Singleton[int]) -> None: resource = Resource(_resource_generator, dummy_singleton.cast) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 73832ad..774bbe0 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -591,9 +591,18 @@ def test_enter_sync_context_for_async_resource_should_throw( async_context_resource._enter_context_sync() -def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None: - with pytest.raises(RuntimeError): - sync_context_resource._exit_context_sync() +def test_exit_sync_context_before_enter_should_throw() -> None: + provider = providers.ContextResource(create_sync_context_resource) + + with pytest.raises(RuntimeError, match=r"Context is not set, call ``_enter_sync_context`` first"): + provider._exit_context_sync() + + +def test_context_sync_manager_exit_before_enter_should_throw( + sync_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError, match=r"Context is not set, call ``__enter__`` first"): + sync_context_resource.context_sync().__exit__(None, None, None) async def test_exit_async_context_before_enter_should_throw( @@ -610,6 +619,18 @@ def test_enter_sync_context_from_async_resource_should_throw( stack.enter_context(async_context_resource.context_sync()) +def test_enter_and_exit_sync_context_directly(sync_context_resource: providers.ContextResource[str]) -> None: + context = sync_context_resource._enter_context_sync() + + assert isinstance(context, ResourceContext) + assert sync_context_resource.resolve_sync() is not None + + sync_context_resource._exit_context_sync() + + with pytest.raises(RuntimeError, match=r"Context is not set. Use container_context"): + sync_context_resource.resolve_sync() + + async def test_preserve_globals_and_initial_context() -> None: initial_context = {"test_1": "test_1", "test_2": "test_2"} diff --git a/tests/providers/test_local_singleton.py b/tests/providers/test_local_singleton.py index a614c95..8d18c92 100644 --- a/tests/providers/test_local_singleton.py +++ b/tests/providers/test_local_singleton.py @@ -4,6 +4,7 @@ import time import typing from concurrent.futures.thread import ThreadPoolExecutor +from unittest.mock import Mock import pytest @@ -51,6 +52,22 @@ async def test_async_thread_local_singleton_asyncio() -> None: assert provider._instance is None, "Tear down failed: Instance should be reset to None." +async def test_thread_local_singleton_reuses_instance_created_while_waiting_on_lock() -> None: + expected_value = 42 + factory = Mock(return_value=1) + provider = ThreadLocalSingleton(factory) + + await provider._asyncio_lock.acquire() + task = asyncio.create_task(provider.resolve()) + await asyncio.sleep(0) + + provider._instance = expected_value + provider._asyncio_lock.release() + + assert await task == expected_value + factory.assert_not_called() + + def test_thread_local_singleton_different_threads() -> None: """Test that different threads receive different instances.""" provider = ThreadLocalSingleton(_factory) diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py index 7bba7bd..1d9f324 100644 --- a/tests/providers/test_singleton.py +++ b/tests/providers/test_singleton.py @@ -144,6 +144,16 @@ async def test_async_singleton_override() -> None: assert result == SingletonFactory(dep1="bar") +def test_async_singleton_register_arguments_is_idempotent() -> None: + parent = providers.Singleton(lambda: "foo") + singleton_async = providers.AsyncSingleton(create_async_obj, value=parent.cast) + + singleton_async._register_arguments() + singleton_async._register_arguments() + + assert singleton_async in parent._children + + async def test_async_singleton_asyncio_concurrency() -> None: singleton_async = providers.AsyncSingleton(create_async_obj, "foo") diff --git a/tests/test_injection.py b/tests/test_injection.py index 448f4ec..bd388ab 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -10,7 +10,14 @@ from tests import container from that_depends import BaseContainer, ContextScopes, Provide, container_context, get_current_scope, inject, providers -from that_depends.injection import ContextProviderError, StringProviderDefinition +from that_depends.injection import ( + _INJECT_DIRECT_PROVIDER, + ContextProviderError, + StringProviderDefinition, + _InjectionParameter, + _resolve_injected_provider, + _SyncInjectionStack, +) @pytest.fixture(name="fixture_one") @@ -55,6 +62,31 @@ async def inner( await inner(True, arg2=container.SimpleFactory(dep1="1", dep2=2)) +def test_sync_injection_stack_closes_entered_context_managers() -> None: + events: list[str] = [] + + @contextmanager + def _managed() -> typing.Iterator[str]: + events.append("enter") + try: + yield "value" + finally: + events.append("exit") + + with _SyncInjectionStack() as stack: + assert stack.enter_context(_managed()) == "value" + events.append("body") + + assert events == ["enter", "body", "exit"] + + +def test_resolve_injected_provider_with_direct_provider() -> None: + provider = providers.Object(1) + parameter = _InjectionParameter(0, "value", _INJECT_DIRECT_PROVIDER, provider) + + assert _resolve_injected_provider(parameter, None) is provider + + async def test_empty_injection() -> None: @inject async def inner(_: int) -> None: @@ -116,6 +148,15 @@ def inner_gen(_: int) -> typing.Generator[int, None, None]: next(inner_gen(1)) +def test_sync_empty_injection_without_scope_warns() -> None: + @inject(scope=None) + def inner(value: int) -> int: + return value + + with pytest.warns(RuntimeWarning, match=r"Expected injection, but nothing found. Remove @inject decorator."): + assert inner(1) == 1 + + def test_type_check() -> None: @inject async def main(simple_factory: container.SimpleFactory = Provide[container.DIContainer.simple_factory]) -> None: @@ -220,6 +261,34 @@ def _injected(val: int = Provide["_Container.sync_resource"]) -> int: assert _injected() == return_value +async def test_async_injection_with_string_provider_definition_respects_explicit_arguments() -> None: + override_value = 10 + + class _Container(BaseContainer): + async_resource = providers.AsyncFactory(_async_creator) + + @inject + async def _injected(val: int = Provide["_Container.async_resource"]) -> int: + return val + + assert await _injected(override_value) == override_value + assert await _injected(val=override_value) == override_value + + +def test_sync_injection_with_string_provider_definition_respects_explicit_arguments() -> None: + override_value = 10 + + class _Container(BaseContainer): + sync_resource = providers.Factory(lambda: 1) + + @inject + def _injected(val: int = Provide["_Container.sync_resource"]) -> int: + return val + + assert _injected(override_value) == override_value + assert _injected(val=override_value) == override_value + + def test_provider_string_definition_with_alias() -> None: return_value = 321 @@ -336,6 +405,46 @@ def _injected(v: float = Provide[_Container.provider_used]) -> float: assert _Container.provider_used.resolve_sync() +async def test_injection_with_string_provider_definition_initializes_parent_context_async() -> None: + async def _async_resource() -> typing.AsyncIterator[float]: + yield random.random() + + class _Container(BaseContainer): + provider_used = providers.ContextResource(_async_resource).with_config(scope=ContextScopes.INJECT) + dependent = providers.Factory(lambda value: value, provider_used.cast) + + @inject + async def _injected(val: float = Provide["_Container.dependent"]) -> float: + assert val == await _Container.dependent.resolve() + assert val == await _Container.provider_used.resolve() + return val + + assert isinstance(await _injected(), float) + + with pytest.raises(RuntimeError): + await _Container.provider_used.resolve() + + +def test_injection_with_string_provider_definition_initializes_parent_context_sync() -> None: + def _sync_resource() -> typing.Iterator[float]: + yield random.random() + + class _Container(BaseContainer): + provider_used = providers.ContextResource(_sync_resource).with_config(scope=ContextScopes.INJECT) + dependent = providers.Factory(lambda value: value, provider_used.cast) + + @inject + def _injected(val: float = Provide["_Container.dependent"]) -> float: + assert val == _Container.dependent.resolve_sync() + assert val == _Container.provider_used.resolve_sync() + return val + + assert isinstance(_injected(), float) + + with pytest.raises(RuntimeError): + _Container.provider_used.resolve_sync() + + async def test_injection_initializes_context_for_parents_async() -> None: async def _async_resource() -> typing.AsyncIterator[float]: yield random.random() From e6c998ab1884a7ce0f2d85113ab3e4784fc2bb8c Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 28 Apr 2026 23:12:03 +0200 Subject: [PATCH 09/13] feat: unset for overrides & instances. --- tests/providers/test_base.py | 7 ++-- tests/providers/test_context_resources.py | 5 +-- tests/providers/test_local_singleton.py | 5 +-- tests/providers/test_resources.py | 5 +-- tests/providers/test_singleton.py | 5 +-- tests/test_container.py | 23 ++++++------ tests/test_multiple_containers.py | 9 ++--- that_depends/entities/resource_context.py | 37 ++++++++++--------- that_depends/providers/base.py | 15 ++++---- that_depends/providers/context_resources.py | 13 +++---- that_depends/providers/factories.py | 7 ++-- that_depends/providers/local_singleton.py | 25 ++++++------- that_depends/providers/object.py | 3 +- that_depends/providers/selector.py | 5 +-- that_depends/providers/singleton.py | 39 +++++++++++---------- 15 files changed, 110 insertions(+), 93 deletions(-) diff --git a/tests/providers/test_base.py b/tests/providers/test_base.py index 5b2084c..7a0601c 100644 --- a/tests/providers/test_base.py +++ b/tests/providers/test_base.py @@ -11,6 +11,7 @@ from that_depends.providers.mixin import SupportsTeardown from that_depends.providers.resources import Resource from that_depends.providers.singleton import AsyncSingleton, Singleton +from that_depends.utils import is_set class DummyProvider(SupportsTeardown, AbstractProvider[int]): @@ -277,7 +278,7 @@ def test_propagate_off() -> None: parent.tear_down_sync(propagate=False) assert child in parent._children - assert child._instance is not None + assert is_set(child._instance) async def test_async_tear_down_propagation_with_singleton() -> None: @@ -288,7 +289,7 @@ async def test_async_tear_down_propagation_with_singleton() -> None: await parent.tear_down() - assert child._instance is None + assert not is_set(child._instance) async def test_async_propagate_off() -> None: @@ -299,7 +300,7 @@ async def test_async_propagate_off() -> None: await parent.tear_down(propagate=False) - assert child._instance is not None + assert is_set(child._instance) async def test_provider_registration_in_different_scope_async() -> None: diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 774bbe0..66405a0 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -24,6 +24,7 @@ from that_depends.meta import DefaultScopeNotDefinedError from that_depends.providers import DIContextMiddleware, container_context from that_depends.providers.context_resources import InvalidContextError, _enter_named_scope, fetch_context_item_by_type +from that_depends.utils import is_set logger = logging.getLogger(__name__) @@ -187,9 +188,9 @@ async def test_early_exit_of_container_context() -> None: async def test_resource_context_early_teardown() -> None: context: ResourceContext[str] = ResourceContext(is_async=True) assert context.is_async is True - assert context.context_stack is None + assert not is_set(context.context_stack) context.tear_down_sync() - assert context.context_stack is None + assert not is_set(context.context_stack) async def test_teardown_sync_container_context_with_async_resource() -> None: diff --git a/tests/providers/test_local_singleton.py b/tests/providers/test_local_singleton.py index 8d18c92..4f92dda 100644 --- a/tests/providers/test_local_singleton.py +++ b/tests/providers/test_local_singleton.py @@ -9,6 +9,7 @@ import pytest from that_depends.providers import AsyncFactory, ThreadLocalSingleton +from that_depends.utils import is_set random.seed(23) @@ -35,7 +36,7 @@ def test_thread_local_singleton_same_thread() -> None: provider.tear_down_sync() - assert provider._instance is None, "Tear down failed: Instance should be reset to None." + assert not is_set(provider._instance), "Tear down failed: Instance should be unset." async def test_async_thread_local_singleton_asyncio() -> None: @@ -49,7 +50,7 @@ async def test_async_thread_local_singleton_asyncio() -> None: await provider.tear_down() - assert provider._instance is None, "Tear down failed: Instance should be reset to None." + assert not is_set(provider._instance), "Tear down failed: Instance should be unset." async def test_thread_local_singleton_reuses_instance_created_while_waiting_on_lock() -> None: diff --git a/tests/providers/test_resources.py b/tests/providers/test_resources.py index a48fc95..239cd6d 100644 --- a/tests/providers/test_resources.py +++ b/tests/providers/test_resources.py @@ -14,6 +14,7 @@ create_sync_resource, ) from that_depends import BaseContainer, providers +from that_depends.utils import is_set logger = logging.getLogger(__name__) @@ -189,9 +190,9 @@ def create_resource() -> typing.Iterator[str]: def test_sync_resource_sync_tear_down() -> None: DIContainer.sync_resource.resolve_sync() - assert DIContainer.sync_resource._context.instance is not None + assert is_set(DIContainer.sync_resource._context.instance) DIContainer.sync_resource.tear_down_sync() - assert DIContainer.sync_resource._context.instance is None + assert not is_set(DIContainer.sync_resource._context.instance) async def test_async_resource_sync_tear_down_raises() -> None: diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py index 1d9f324..3be5b91 100644 --- a/tests/providers/test_singleton.py +++ b/tests/providers/test_singleton.py @@ -9,6 +9,7 @@ import pytest from that_depends import BaseContainer, providers +from that_depends.utils import is_set @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) @@ -197,9 +198,9 @@ async def test_async_singleton_teardown() -> None: await singleton_async.resolve() singleton_async.tear_down_sync() - assert singleton_async._instance is None + assert not is_set(singleton_async._instance) await singleton_async.resolve() await singleton_async.tear_down() - assert singleton_async._instance is None + assert not is_set(singleton_async._instance) diff --git a/tests/test_container.py b/tests/test_container.py index f7f4752..fcc5ac0 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -3,6 +3,7 @@ from tests.container import DIContainer from that_depends import BaseContainer, providers +from that_depends.utils import is_set def _sync_resource() -> typing.Iterator[float]: @@ -15,11 +16,11 @@ async def test_container_sync_teardown() -> None: for provider in DIContainer.providers.values(): if isinstance(provider, providers.Resource): if provider._is_async: - assert provider._context.instance is not None + assert is_set(provider._context.instance) else: - assert provider._context.instance is None + assert not is_set(provider._context.instance) if isinstance(provider, providers.Singleton): - assert provider._instance is None + assert not is_set(provider._instance) async def test_container_tear_down() -> None: @@ -27,9 +28,9 @@ async def test_container_tear_down() -> None: await DIContainer.tear_down() for provider in DIContainer.providers.values(): if isinstance(provider, providers.Resource): - assert provider._context.instance is None + assert not is_set(provider._context.instance) if isinstance(provider, providers.Singleton): - assert provider._instance is None + assert not is_set(provider._instance) async def test_container_sync_tear_down_propagation() -> None: @@ -46,8 +47,8 @@ class _DependentContainer(BaseContainer): DIContainer.tear_down_sync() - assert _DependentContainer.singleton._instance is None - assert _DependentContainer.resource._context.instance is None + assert not is_set(_DependentContainer.singleton._instance) + assert not is_set(_DependentContainer.resource._context.instance) async def test_container_tear_down_propagation() -> None: @@ -74,7 +75,7 @@ class _DependentContainer(BaseContainer): await DIContainer.tear_down() - assert _DependentContainer.async_singleton._instance is None - assert _DependentContainer.async_resource._context.instance is None - assert _DependentContainer.sync_singleton._instance is None - assert _DependentContainer.sync_resource._context.instance is None + assert not is_set(_DependentContainer.async_singleton._instance) + assert not is_set(_DependentContainer.async_resource._context.instance) + assert not is_set(_DependentContainer.sync_singleton._instance) + assert not is_set(_DependentContainer.sync_resource._context.instance) diff --git a/tests/test_multiple_containers.py b/tests/test_multiple_containers.py index 6e74282..69a2246 100644 --- a/tests/test_multiple_containers.py +++ b/tests/test_multiple_containers.py @@ -4,6 +4,7 @@ from tests import container from that_depends import BaseContainer, providers +from that_depends.utils import is_set class InnerContainer(BaseContainer): @@ -23,16 +24,16 @@ async def test_included_container() -> None: assert all(isinstance(x, datetime.datetime) for x in sequence) await OuterContainer.tear_down() - assert InnerContainer.sync_resource._context.instance is None - assert InnerContainer.async_resource._context.instance is None + assert not is_set(InnerContainer.sync_resource._context.instance) + assert not is_set(InnerContainer.async_resource._context.instance) await OuterContainer.init_resources() sync_resource_context = InnerContainer.sync_resource._context assert sync_resource_context - assert sync_resource_context.instance is not None + assert is_set(sync_resource_context.instance) async_resource_context = InnerContainer.async_resource._context assert async_resource_context - assert async_resource_context.instance is not None + assert is_set(async_resource_context.instance) await OuterContainer.tear_down() diff --git a/that_depends/entities/resource_context.py b/that_depends/entities/resource_context.py index d94e5b4..725c0a9 100644 --- a/that_depends/entities/resource_context.py +++ b/that_depends/entities/resource_context.py @@ -7,6 +7,7 @@ from typing_extensions import override from that_depends.providers.mixin import CannotTearDownSyncError, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -26,17 +27,17 @@ def __init__(self, is_async: bool) -> None: For example within a ``async with container_context(Container): ...`` statement. """ - self.instance: T_co | None = None + self.instance: T_co | Unset = UNSET self.asyncio_lock: typing.Final = asyncio.Lock() self.threading_lock: typing.Final = threading.Lock() - self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None + self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET self.is_async = is_async def set_context_state( self, *, - instance: T_co | None = None, - context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None, + instance: T_co | Unset = UNSET, + context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | Unset = UNSET, ) -> None: """Set the context state of the resource. @@ -53,14 +54,14 @@ def set_context_state( @staticmethod def is_context_stack_async( - context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None, + context_stack: object, ) -> typing.TypeGuard[contextlib.AsyncExitStack]: """Check if the context stack is an async context stack.""" return isinstance(context_stack, contextlib.AsyncExitStack) @staticmethod def is_context_stack_sync( - context_stack: contextlib.AsyncExitStack | contextlib.ExitStack, + context_stack: object, ) -> typing.TypeGuard[contextlib.ExitStack]: """Check if the context stack is a sync context stack.""" return isinstance(context_stack, contextlib.ExitStack) @@ -68,25 +69,27 @@ def is_context_stack_sync( @override async def tear_down(self, propagate: bool = True) -> None: """Tear down the async context stack.""" - if self.context_stack is None: + context_stack = self.context_stack + if not is_set(context_stack): return - if self.is_context_stack_async(self.context_stack): - await self.context_stack.aclose() - elif self.is_context_stack_sync(self.context_stack): - self.context_stack.close() - self.set_context_state(instance=None, context_stack=None) + if self.is_context_stack_async(context_stack): + await context_stack.aclose() + elif self.is_context_stack_sync(context_stack): + context_stack.close() + self.set_context_state(instance=UNSET, context_stack=UNSET) @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: """Tear down the sync context stack.""" - if self.context_stack is None: + context_stack = self.context_stack + if not is_set(context_stack): return - if self.is_context_stack_sync(self.context_stack): - self.context_stack.close() - self.set_context_state(instance=None, context_stack=None) - elif self.is_context_stack_async(self.context_stack): + if self.is_context_stack_sync(context_stack): + context_stack.close() + self.set_context_state(instance=UNSET, context_stack=UNSET) + elif self.is_context_stack_async(context_stack): msg = "Cannot tear down async context in sync mode" if raise_on_async: raise CannotTearDownSyncError(msg) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 8713001..7adc02d 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -11,6 +11,7 @@ from that_depends.entities.resource_context import ResourceContext from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -96,7 +97,7 @@ def __init__(self) -> None: self._is_context_resource = False self._scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None self._scope_init_order: tuple[AbstractProvider[typing.Any], ...] | None = None - self._override: typing.Any = None + self._override: typing.Any = UNSET self._bindings: set[type] = set() self._has_contravariant_bindings: bool = False self._lock = threading.Lock() @@ -363,7 +364,7 @@ def reset_override_sync( None """ - self._override = None + self._override = UNSET if tear_down_children: eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)] for child in eligible_children: @@ -380,7 +381,7 @@ async def reset_override(self, tear_down_children: bool = False, propagate: bool None """ - self._override = None + self._override = UNSET if tear_down_children: eligible_children = [child for child in self._children if isinstance(child, SupportsTeardown)] for child in eligible_children: @@ -479,14 +480,14 @@ def _fetch_context(self) -> ResourceContext[T_co]: ... @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) context = self._fetch_context() # lock to prevent race condition while resolving async with context.asyncio_lock: - if context.instance is not None: + if is_set(context.instance): return context.instance self._register_arguments() @@ -510,14 +511,14 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) context = self._fetch_context() # lock to prevent race condition while resolving with context.threading_lock: - if context.instance is not None: + if is_set(context.instance): return context.instance if self._is_async: diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index c8f966e..b4363ee 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -14,6 +14,7 @@ from that_depends.entities.resource_context import ResourceContext from that_depends.providers.base import AbstractResource +from that_depends.utils import UNSET if typing.TYPE_CHECKING: @@ -55,10 +56,10 @@ def __init__( def close(self) -> None: context_stack = self._context_item.context_stack - if context_stack is not None: + if self._context_item.is_context_stack_sync(context_stack): context_stack.close() # type: ignore[union-attr] - self._context_item.context_stack = None - self._context_item.instance = None + self._context_item.context_stack = UNSET + self._context_item.instance = UNSET self._context.reset(self._token) @@ -444,10 +445,10 @@ def _exit_context_sync(self) -> None: try: context_item = self._context.get() context_stack = context_item.context_stack - if context_stack is not None: + if context_item.is_context_stack_sync(context_stack): context_stack.close() # type: ignore[union-attr] - context_item.context_stack = None - context_item.instance = None + context_item.context_stack = UNSET + context_item.instance = UNSET finally: self._context.reset(self._token) diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index c607794..9102d0d 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -13,6 +13,7 @@ _resolve_keyword_arguments_sync, ) from that_depends.providers.mixin import ProviderWithArguments +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -122,7 +123,7 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) return self._factory( @@ -132,7 +133,7 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) return self._factory( @@ -211,7 +212,7 @@ def __init__( @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) args = await _resolve_arguments(self._args, self._args_are_providers) diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index dc32106..20228f1 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -12,6 +12,7 @@ _resolve_keyword_arguments_sync, ) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -79,22 +80,22 @@ def _deregister_arguments(self) -> None: self._reset_arguments_registration() @property - def _instance(self) -> T_co | None: - return getattr(self._thread_local, "instance", None) + def _instance(self) -> T_co | Unset: + return typing.cast(T_co | Unset, getattr(self._thread_local, "instance", UNSET)) @_instance.setter - def _instance(self, value: T_co | None) -> None: + def _instance(self, value: T_co | Unset) -> None: self._thread_local.instance = value @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() @@ -108,10 +109,10 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() @@ -125,8 +126,8 @@ def resolve_sync(self) -> T_co: @override def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) @@ -138,8 +139,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, subsequent calls to `resolve_sync()` on the same thread will produce a new instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py index b29ed7c..c0843ea 100644 --- a/that_depends/providers/object.py +++ b/that_depends/providers/object.py @@ -3,6 +3,7 @@ from typing_extensions import override from that_depends.providers.base import AbstractProvider +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -42,6 +43,6 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) return self._obj diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py index 5bbe023..9ec2c48 100644 --- a/that_depends/providers/selector.py +++ b/that_depends/providers/selector.py @@ -5,6 +5,7 @@ from typing_extensions import override from that_depends.providers.base import AbstractProvider +from that_depends.utils import is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -69,7 +70,7 @@ def my_selector(): @override async def resolve(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) if isinstance(self._selector, AbstractProvider): @@ -83,7 +84,7 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override: + if is_set(self._override): return typing.cast(T_co, self._override) if isinstance(self._selector, AbstractProvider): diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 2a1b29c..961f77a 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -14,6 +14,7 @@ _resolve_keyword_arguments_sync, ) from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown +from that_depends.utils import UNSET, Unset, is_set T_co = typing.TypeVar("T_co", covariant=True) @@ -64,7 +65,7 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() self._factory: typing.Final[typing.Callable[..., T_co]] = factory - self._instance: T_co | None = None + self._instance: T_co | Unset = UNSET self._asyncio_lock: typing.Final = asyncio.Lock() self._threading_lock: typing.Final = threading.Lock() self._args: typing.Final = args @@ -88,15 +89,15 @@ def _deregister_arguments(self) -> None: @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): self._register_arguments() return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance # lock to prevent resolving several times async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() self._instance = self._factory( @@ -107,15 +108,15 @@ async def resolve(self) -> T_co: @override def resolve_sync(self) -> T_co: - if self._override is not None: + if is_set(self._override): self._register_arguments() return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance # lock to prevent resolving several times with self._threading_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() self._instance = self._factory( @@ -130,8 +131,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, the next resolve() call will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() @@ -142,8 +143,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> After calling this method, the next resolve call will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) @@ -197,7 +198,7 @@ def __init__( """ super().__init__() self._factory: typing.Final[typing.Callable[..., typing.Awaitable[T_co]]] = factory - self._instance: T_co | None = None + self._instance: T_co | Unset = UNSET self._asyncio_lock: typing.Final = asyncio.Lock() self._args: typing.Final = args self._kwargs: typing.Final = kwargs @@ -220,14 +221,14 @@ def _deregister_arguments(self) -> None: @override async def resolve(self) -> T_co: - if self._override is not None: + if is_set(self._override): return typing.cast(T_co, self._override) - if self._instance is not None: + if is_set(self._instance): return self._instance # lock to prevent resolving several times async with self._asyncio_lock: - if self._instance is not None: + if is_set(self._instance): return self._instance self._register_arguments() @@ -249,8 +250,8 @@ async def tear_down(self, propagate: bool = True) -> None: After calling this method, the next call to ``resolve()`` will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: await self._tear_down_children() @@ -261,8 +262,8 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> After calling this method, the next call to ``resolve_sync()`` will recreate the instance. """ - if self._instance is not None: - self._instance = None + if is_set(self._instance): + self._instance = UNSET self._deregister_arguments() if propagate: self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) From 30f8b35399ed6f748e890b5fe5122b97fb2627d0 Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 29 Apr 2026 10:39:42 +0200 Subject: [PATCH 10/13] cleanup: minor injection plan changes. --- tests/test_injection.py | 14 ++++++++++++-- that_depends/injection.py | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/tests/test_injection.py b/tests/test_injection.py index bd388ab..859e5f3 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -11,9 +11,10 @@ from tests import container from that_depends import BaseContainer, ContextScopes, Provide, container_context, get_current_scope, inject, providers from that_depends.injection import ( - _INJECT_DIRECT_PROVIDER, ContextProviderError, StringProviderDefinition, + _build_injection_plan, + _InjectionKind, _InjectionParameter, _resolve_injected_provider, _SyncInjectionStack, @@ -82,11 +83,20 @@ def _managed() -> typing.Iterator[str]: def test_resolve_injected_provider_with_direct_provider() -> None: provider = providers.Object(1) - parameter = _InjectionParameter(0, "value", _INJECT_DIRECT_PROVIDER, provider) + parameter = _InjectionParameter(0, "value", _InjectionKind.DIRECT_PROVIDER, provider) assert _resolve_injected_provider(parameter, None) is provider +def test_build_injection_plan_stores_annotation_for_type_based_injection() -> None: + def _injected(value: float = Provide()) -> float: + return value + + plan = _build_injection_plan(_injected) + + assert plan.dynamic_parameters == (_InjectionParameter(0, "value", _InjectionKind.TYPED_PROVIDER, float),) + + async def test_empty_injection() -> None: @inject async def inner(_: int) -> None: diff --git a/that_depends/injection.py b/that_depends/injection.py index abd4edc..0d37023 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -1,3 +1,4 @@ +import enum import functools import inspect import re @@ -26,16 +27,26 @@ class ContextProviderError(Exception): _PROVIDE_MESSAGE: typing.Final[str] = ( "Use @Container.inject or @inject(container=Container) if you wish to use Provide()" ) -_INJECT_DIRECT_PROVIDER: typing.Final = 1 -_INJECT_STRING_PROVIDER: typing.Final = 2 -_INJECT_TYPED_PROVIDER: typing.Final = 3 + + +class _InjectionKind(enum.Enum): + DIRECT_PROVIDER = enum.auto() + STRING_PROVIDER = enum.auto() + TYPED_PROVIDER = enum.auto() + + +_InjectionDependency: typing.TypeAlias = typing.Union[ + AbstractProvider[typing.Any], + "StringProviderDefinition", + type[typing.Any], +] class _InjectionParameter(typing.NamedTuple): argument_index: int field_name: str - kind: int - dependency: typing.Any + kind: _InjectionKind + dependency: _InjectionDependency class _DirectInjectionParameter(typing.NamedTuple): @@ -100,7 +111,7 @@ def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionP for index, (field_name, param) in enumerate(inspect.signature(func).parameters.items()): default = param.default if isinstance(default, StringProviderDefinition): - dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_STRING_PROVIDER, default)) + dynamic_parameters.append(_InjectionParameter(index, field_name, _InjectionKind.STRING_PROVIDER, default)) elif isinstance(default, AbstractProvider): direct_parameters.append( _DirectInjectionParameter( @@ -111,7 +122,14 @@ def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionP ) ) elif isinstance(default, _Provide): - dynamic_parameters.append(_InjectionParameter(index, field_name, _INJECT_TYPED_PROVIDER, param.annotation)) + dynamic_parameters.append( + _InjectionParameter( + index, + field_name, + _InjectionKind.TYPED_PROVIDER, + typing.cast(type[typing.Any], param.annotation), + ) + ) return _InjectionPlan( direct_parameters=tuple(direct_parameters), dynamic_parameters=tuple(dynamic_parameters), @@ -327,14 +345,14 @@ def _resolve_injected_provider( parameter: _InjectionParameter, container: BaseContainerMeta | None, ) -> AbstractProvider[typing.Any]: - if parameter.kind == _INJECT_DIRECT_PROVIDER: + if parameter.kind is _InjectionKind.DIRECT_PROVIDER: return typing.cast(AbstractProvider[typing.Any], parameter.dependency) - if parameter.kind == _INJECT_STRING_PROVIDER: + if parameter.kind is _InjectionKind.STRING_PROVIDER: string_definition = typing.cast(StringProviderDefinition, parameter.dependency) return string_definition.provider if container is None: raise RuntimeError(_PROVIDE_MESSAGE) - annotation = parameter.dependency + annotation = typing.cast(type[typing.Any], parameter.dependency) try: return container.get_provider_for_type(annotation) except TypeNotBoundError as e: From 9b90d7420b97abee78db39c3fe29e021e74e47c2 Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 29 Apr 2026 11:30:01 +0200 Subject: [PATCH 11/13] cleanup: refactor. --- tests/test_injection.py | 62 ++++++++++++++++++--- that_depends/injection.py | 112 +++++++++++++++++++++++--------------- 2 files changed, 122 insertions(+), 52 deletions(-) diff --git a/tests/test_injection.py b/tests/test_injection.py index 859e5f3..f4924f4 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -14,10 +14,9 @@ ContextProviderError, StringProviderDefinition, _build_injection_plan, - _InjectionKind, - _InjectionParameter, - _resolve_injected_provider, + _DirectInjectionParameter, _SyncInjectionStack, + _TypedInjectionParameter, ) @@ -81,11 +80,18 @@ def _managed() -> typing.Iterator[str]: assert events == ["enter", "body", "exit"] -def test_resolve_injected_provider_with_direct_provider() -> None: +def test_build_injection_plan_stores_direct_provider_separately() -> None: provider = providers.Object(1) - parameter = _InjectionParameter(0, "value", _InjectionKind.DIRECT_PROVIDER, provider) - assert _resolve_injected_provider(parameter, None) is provider + def _injected(value: providers.Object[int] = provider) -> providers.Object[int]: + return value + + plan = _build_injection_plan(_injected) + + assert _injected(provider) is provider + assert plan.direct_parameters == ( + _DirectInjectionParameter(0, "value", provider, provider._get_scope_context_init_order()), + ) def test_build_injection_plan_stores_annotation_for_type_based_injection() -> None: @@ -94,7 +100,21 @@ def _injected(value: float = Provide()) -> float: plan = _build_injection_plan(_injected) - assert plan.dynamic_parameters == (_InjectionParameter(0, "value", _InjectionKind.TYPED_PROVIDER, float),) + assert _injected(1.0) == 1.0 + assert plan.typed_parameters == (_TypedInjectionParameter(0, "value", float),) + + +def test_build_injection_plan_stores_string_provider_separately() -> None: + def _injected(value: int = Provide["Container.provider"]) -> int: + return value + + plan = _build_injection_plan(_injected) + + assert _injected(1) == 1 + assert len(plan.string_parameters) == 1 + assert plan.string_parameters[0].argument_index == 0 + assert plan.string_parameters[0].field_name == "value" + assert plan.string_parameters[0].definition._definition == "Container.provider" async def test_empty_injection() -> None: @@ -982,6 +1002,20 @@ def _injected_2(val: float = Provide()) -> float: assert isinstance(_injected_2(), float) +def test_injection_by_type_sync_respects_explicit_arguments() -> None: + class _Container(BaseContainer): + sync_resource = providers.Factory(random.random).bind(float) + + override_value = 10.0 + + @inject(container=_Container) + def _injected(val: float = Provide()) -> float: + return val + + assert _injected(override_value) == override_value + assert _injected(val=override_value) == override_value + + async def test_injection_by_type_async() -> None: class _Container(BaseContainer): sync_resource = providers.Factory(random.random).bind(float) @@ -998,6 +1032,20 @@ async def _injected_2(val: float = Provide()) -> float: assert isinstance(await _injected_2(), float) +async def test_injection_by_type_async_respects_explicit_arguments() -> None: + class _Container(BaseContainer): + sync_resource = providers.Factory(random.random).bind(float) + + override_value = 10.0 + + @inject(container=_Container) + async def _injected(val: float = Provide()) -> float: + return val + + assert await _injected(override_value) == override_value + assert await _injected(val=override_value) == override_value + + async def test_injection_by_type_async_generator() -> None: class _Container(BaseContainer): sync_resource = providers.Factory(random.random).bind(float) diff --git a/that_depends/injection.py b/that_depends/injection.py index 0d37023..75e70d6 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -1,4 +1,3 @@ -import enum import functools import inspect import re @@ -29,36 +28,29 @@ class ContextProviderError(Exception): ) -class _InjectionKind(enum.Enum): - DIRECT_PROVIDER = enum.auto() - STRING_PROVIDER = enum.auto() - TYPED_PROVIDER = enum.auto() - - -_InjectionDependency: typing.TypeAlias = typing.Union[ - AbstractProvider[typing.Any], - "StringProviderDefinition", - type[typing.Any], -] +class _DirectInjectionParameter(typing.NamedTuple): + argument_index: int + field_name: str + provider: AbstractProvider[typing.Any] + scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] -class _InjectionParameter(typing.NamedTuple): +class _StringInjectionParameter(typing.NamedTuple): argument_index: int field_name: str - kind: _InjectionKind - dependency: _InjectionDependency + definition: "StringProviderDefinition" -class _DirectInjectionParameter(typing.NamedTuple): +class _TypedInjectionParameter(typing.NamedTuple): argument_index: int field_name: str - provider: AbstractProvider[typing.Any] - scope_context_init_order: tuple[AbstractProvider[typing.Any], ...] + annotation: type[typing.Any] class _InjectionPlan(typing.NamedTuple): direct_parameters: tuple[_DirectInjectionParameter, ...] - dynamic_parameters: tuple[_InjectionParameter, ...] + string_parameters: tuple[_StringInjectionParameter, ...] + typed_parameters: tuple[_TypedInjectionParameter, ...] class _SyncInjectionStack: @@ -107,11 +99,12 @@ def close(self) -> None: @functools.cache def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionPlan: direct_parameters: list[_DirectInjectionParameter] = [] - dynamic_parameters: list[_InjectionParameter] = [] + string_parameters: list[_StringInjectionParameter] = [] + typed_parameters: list[_TypedInjectionParameter] = [] for index, (field_name, param) in enumerate(inspect.signature(func).parameters.items()): default = param.default if isinstance(default, StringProviderDefinition): - dynamic_parameters.append(_InjectionParameter(index, field_name, _InjectionKind.STRING_PROVIDER, default)) + string_parameters.append(_StringInjectionParameter(index, field_name, default)) elif isinstance(default, AbstractProvider): direct_parameters.append( _DirectInjectionParameter( @@ -122,17 +115,17 @@ def _build_injection_plan(func: typing.Callable[..., typing.Any]) -> _InjectionP ) ) elif isinstance(default, _Provide): - dynamic_parameters.append( - _InjectionParameter( + typed_parameters.append( + _TypedInjectionParameter( index, field_name, - _InjectionKind.TYPED_PROVIDER, typing.cast(type[typing.Any], param.annotation), ) ) return _InjectionPlan( direct_parameters=tuple(direct_parameters), - dynamic_parameters=tuple(dynamic_parameters), + string_parameters=tuple(string_parameters), + typed_parameters=tuple(typed_parameters), ) @@ -270,12 +263,12 @@ async def _resolve_arguments_async( *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - if not plan.direct_parameters and not plan.dynamic_parameters: + if not _plan_has_injected_parameters(plan): return False, kwargs context_providers: set[AbstractProvider[typing.Any]] = set() for direct_parameter in plan.direct_parameters: - if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: + if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs): continue if direct_parameter.scope_context_init_order: @@ -287,12 +280,23 @@ async def _resolve_arguments_async( ) kwargs[direct_parameter.field_name] = await direct_parameter.provider.resolve() - for dynamic_parameter in plan.dynamic_parameters: - if dynamic_parameter.argument_index < len(args) or dynamic_parameter.field_name in kwargs: + for string_parameter in plan.string_parameters: + if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs): + continue + + kwargs[string_parameter.field_name] = await _resolve_provider_with_scope_async( + string_parameter.definition.provider, + scope, + stack, + context_providers, + ) + + for typed_parameter in plan.typed_parameters: + if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs): continue - provider = _resolve_injected_provider(dynamic_parameter, container) - kwargs[dynamic_parameter.field_name] = await _resolve_provider_with_scope_async( + provider = _resolve_typed_provider(typed_parameter.annotation, container) + kwargs[typed_parameter.field_name] = await _resolve_provider_with_scope_async( provider, scope, stack, @@ -309,12 +313,12 @@ def _resolve_arguments_sync( *args: typing.Any, # noqa: ANN401 **kwargs: typing.Any, # noqa: ANN401 ) -> tuple[bool, dict[str, typing.Any]]: - if not plan.direct_parameters and not plan.dynamic_parameters: + if not _plan_has_injected_parameters(plan): return False, kwargs context_providers: set[AbstractProvider[typing.Any]] = set() for direct_parameter in plan.direct_parameters: - if direct_parameter.argument_index < len(args) or direct_parameter.field_name in kwargs: + if _is_argument_provided(direct_parameter.argument_index, direct_parameter.field_name, args, kwargs): continue if direct_parameter.scope_context_init_order: @@ -326,12 +330,23 @@ def _resolve_arguments_sync( ) kwargs[direct_parameter.field_name] = direct_parameter.provider.resolve_sync() - for dynamic_parameter in plan.dynamic_parameters: - if dynamic_parameter.argument_index < len(args) or dynamic_parameter.field_name in kwargs: + for string_parameter in plan.string_parameters: + if _is_argument_provided(string_parameter.argument_index, string_parameter.field_name, args, kwargs): + continue + + kwargs[string_parameter.field_name] = _resolve_provider_with_scope_sync( + string_parameter.definition.provider, + scope, + stack, + context_providers, + ) + + for typed_parameter in plan.typed_parameters: + if _is_argument_provided(typed_parameter.argument_index, typed_parameter.field_name, args, kwargs): continue - provider = _resolve_injected_provider(dynamic_parameter, container) - kwargs[dynamic_parameter.field_name] = _resolve_provider_with_scope_sync( + provider = _resolve_typed_provider(typed_parameter.annotation, container) + kwargs[typed_parameter.field_name] = _resolve_provider_with_scope_sync( provider, scope, stack, @@ -341,18 +356,25 @@ def _resolve_arguments_sync( return True, kwargs -def _resolve_injected_provider( - parameter: _InjectionParameter, +def _plan_has_injected_parameters(plan: _InjectionPlan) -> bool: + return bool(plan.direct_parameters or plan.string_parameters or plan.typed_parameters) + + +def _is_argument_provided( + argument_index: int, + field_name: str, + args: tuple[typing.Any, ...], + kwargs: dict[str, typing.Any], +) -> bool: + return argument_index < len(args) or field_name in kwargs + + +def _resolve_typed_provider( + annotation: type[typing.Any], container: BaseContainerMeta | None, ) -> AbstractProvider[typing.Any]: - if parameter.kind is _InjectionKind.DIRECT_PROVIDER: - return typing.cast(AbstractProvider[typing.Any], parameter.dependency) - if parameter.kind is _InjectionKind.STRING_PROVIDER: - string_definition = typing.cast(StringProviderDefinition, parameter.dependency) - return string_definition.provider if container is None: raise RuntimeError(_PROVIDE_MESSAGE) - annotation = typing.cast(type[typing.Any], parameter.dependency) try: return container.get_provider_for_type(annotation) except TypeNotBoundError as e: From 833950076eb3c30f03b9da93676acf100d646d5c Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 29 Apr 2026 11:40:23 +0200 Subject: [PATCH 12/13] docs: migration guide --- docs/migration/v4.md | 92 ++++++++++++++++++++++++++++++++++++++++ docs/overrides/main.html | 2 +- mkdocs.yml | 1 + 3 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 docs/migration/v4.md diff --git a/docs/migration/v4.md b/docs/migration/v4.md new file mode 100644 index 0000000..65cb932 --- /dev/null +++ b/docs/migration/v4.md @@ -0,0 +1,92 @@ +# Migrating from 3.\* to 4.\* + +## How to Read This Guide + +This guide is intended to help you migrate existing functionality from `that-depends` version `3.*` to `4.*`. +The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase. + +This migration intentionally focuses only on the **simplest change needed to preserve existing behaviour**. + +If you want to learn more about the new internals introduced in `4.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases). + +--- + +## Changes in the API + +### **Collection providers now return read-only container types** + +In `4.*`, collection providers no longer resolve to mutable built-in containers: + +- `providers.List(...)` now resolves to a `Sequence` implemented as a `tuple` +- `providers.Dict(...)` now resolves to a `Mapping` implemented as a `mappingproxy` + +If your existing code only reads from these values, you likely do not need to change anything. + +If your existing code mutates the resolved value, the **simplest way to preserve the old behaviour** is to convert the result at the call site: + +```python +items = list(MyContainer.items.resolve_sync()) +mapping = dict(MyContainer.mapping.resolve_sync()) +``` + +--- + +## Behaviour-Preserving Migration Examples + +### **`providers.List(...)`** + +Previously in `3.*`, code like this returned a `list`: + +```python +items = MyContainer.items.resolve_sync() +items.append("new-item") +``` + +In `4.*`, `items` is a read-only sequence, so mutating it directly will no longer work. + +To preserve the previous behaviour, change it to: + +```python +items = list(MyContainer.items.resolve_sync()) +items.append("new-item") +``` + +The same applies to async resolution: + +```python +items = list(await MyContainer.items.resolve()) +items.append("new-item") +``` + +--- + +### **`providers.Dict(...)`** + +Previously in `3.*`, code like this returned a mutable `dict`: + +```python +mapping = MyContainer.mapping.resolve_sync() +mapping["extra"] = "value" +``` + +In `4.*`, `mapping` is a read-only mapping, so item assignment will no longer work. + +To preserve the previous behaviour, change it to: + +```python +mapping = dict(MyContainer.mapping.resolve_sync()) +mapping["extra"] = "value" +``` + +The same applies to async resolution: + +```python +mapping = dict(await MyContainer.mapping.resolve()) +mapping["extra"] = "value" +``` + +--- + +## Further Help + +If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues). diff --git a/docs/overrides/main.html b/docs/overrides/main.html index 7496f3a..1f99571 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -1,5 +1,5 @@ {% extends "base.html" %} {% block announce %} - that-depends 3.0 just released! If you are upgrading, check out the migration guide. + that-depends 4.0 just released! If you are upgrading, check out the migration guide. {% endblock %} diff --git a/mkdocs.yml b/mkdocs.yml index 67e1bb7..416513c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,6 +37,7 @@ nav: - Migration: - 1.* to 2.*: migration/v2.md - 2.* to 3.*: migration/v3.md + - 3.* to 4.*: migration/v4.md - Development: - Contributing: dev/contributing.md - Decisions: dev/main-decisions.md From df6e2a7beadfd2994f89225a61b6f5644d690691 Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 29 Apr 2026 11:45:10 +0200 Subject: [PATCH 13/13] fix: address review comments. --- tests/providers/test_context_resources.py | 50 +++++++++++++++++++++ tests/test_injection.py | 39 +++++++++++++++- that_depends/injection.py | 4 +- that_depends/providers/context_resources.py | 18 ++++---- 4 files changed, 99 insertions(+), 12 deletions(-) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 66405a0..02a39c8 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -565,6 +565,56 @@ def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: assert _explicit_injected() != _explicit_injected() +def test_sync_context_resource_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope) + + with _enter_named_scope(current_scope), resource.context_sync(): + assert isinstance(resource.resolve_sync(), str) + + +async def test_async_context_resource_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope) + + with _enter_named_scope(current_scope): + async with resource.context_async(): + assert isinstance(await resource.resolve(), str) + + +async def test_async_context_resource_strict_scope_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_async_context_resource).with_config(configured_scope, strict_scope=True) + + with _enter_named_scope(current_scope): + async with resource.context_async(force=True): + assert isinstance(await resource.resolve(), str) + + +def test_sync_context_resource_strict_scope_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope, strict_scope=True) + + with _enter_named_scope(current_scope), resource.context_sync(force=True): + assert isinstance(resource.resolve_sync(), str) + + +def test_container_context_accepts_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + current_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + sync_context_resource = providers.ContextResource(create_sync_context_resource).with_config(configured_scope) + + with container_context(_Container, scope=current_scope): + assert isinstance(_Container.sync_context_resource.resolve_sync(), str) + + async def test_async_context_resource_with_dependent_container() -> None: """Container should initialize async context resource for dependent containers.""" async with DIContainer.context_async(): diff --git a/tests/test_injection.py b/tests/test_injection.py index f4924f4..ac4f3d5 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -9,7 +9,16 @@ import pytest from tests import container -from that_depends import BaseContainer, ContextScopes, Provide, container_context, get_current_scope, inject, providers +from that_depends import ( + BaseContainer, + ContextScope, + ContextScopes, + Provide, + container_context, + get_current_scope, + inject, + providers, +) from that_depends.injection import ( ContextProviderError, StringProviderDefinition, @@ -211,6 +220,20 @@ async def _injected(val: int = Provide[_Container.async_resource]) -> int: await inject(scope=ContextScopes.REQUEST)(_injected)() +async def test_async_injection_with_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + injection_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + async_resource = providers.ContextResource(_async_creator).with_config(scope=configured_scope) + + async def _injected(val: int = Provide[_Container.async_resource]) -> int: + return val + + assert await inject(scope=injection_scope)(_injected)() == 1 + + async def test_sync_injection_with_scope() -> None: class _Container(BaseContainer): default_scope = ContextScopes.ANY @@ -227,6 +250,20 @@ def _injected(val: int = Provide[_Container.p_inject]) -> int: inject(scope=ContextScopes.REQUEST)(_injected)() +def test_sync_injection_with_equal_scope_instance() -> None: + configured_scope = ContextScope("MATCHING_SCOPE") + injection_scope = ContextScope("MATCHING_SCOPE") + + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + p_inject = providers.ContextResource(_sync_creator).with_config(scope=configured_scope) + + def _injected(val: int = Provide[_Container.p_inject]) -> int: + return val + + assert inject(scope=injection_scope)(_injected)() == 1 + + def test_inject_decorator_should_not_allow_any_scope() -> None: with pytest.raises(ValueError, match=f"{ContextScopes.ANY} is not allowed in inject decorator."): inject(scope=ContextScopes.ANY) diff --git a/that_depends/injection.py b/that_depends/injection.py index 75e70d6..bd020b7 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -468,7 +468,7 @@ async def _setup_scope_contexts_async( continue providers.add(provider) provider_scope = provider._scope # noqa: SLF001 - if provider_scope is ContextScopes.ANY or provider_scope is scope: + if provider_scope in (ContextScopes.ANY, scope): if stack is None: msg = ( f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" @@ -503,7 +503,7 @@ def _setup_scope_contexts_sync( continue providers.add(provider) provider_scope = provider._scope # noqa: SLF001 - if provider_scope is ContextScopes.ANY or provider_scope is scope: + if provider_scope in (ContextScopes.ANY, scope): if stack is None: msg = ( f"No stack exists, cannot initialize context for {provider} using scope {scope}.\n" diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index b4363ee..549b3cd 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -280,20 +280,20 @@ class ContextResource( @override async def resolve(self) -> T_co: - if not self._strict_scope or self._scope is ContextScopes.ANY: + if not self._strict_scope or self._scope == ContextScopes.ANY: return await super().resolve() current_scope = get_current_scope() - if self._scope is current_scope: + if self._scope == current_scope: return await super().resolve() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @override def resolve_sync(self) -> T_co: - if not self._strict_scope or self._scope is ContextScopes.ANY: + if not self._strict_scope or self._scope == ContextScopes.ANY: return super().resolve_sync() current_scope = get_current_scope() - if self._scope is current_scope: + if self._scope == current_scope: return super().resolve_sync() msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" raise RuntimeError(msg) @@ -414,9 +414,9 @@ def _enter_injection_context_sync( if self._is_async: msg = "Please use async context instead." raise RuntimeError(msg) - if not force and self._scope is not ContextScopes.ANY: + if not force and self._scope != ContextScopes.ANY: current_scope = get_current_scope() - if self._scope is not current_scope: + if self._scope != current_scope: msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" raise InvalidContextError(msg) @@ -428,9 +428,9 @@ async def _enter_context_async(self, force: bool = False) -> ResourceContext[T_c return self._enter(force) def _enter(self, force: bool = False) -> ResourceContext[T_co]: - if not force and self._scope is not ContextScopes.ANY: + if not force and self._scope != ContextScopes.ANY: current_scope = get_current_scope() - if self._scope is not current_scope: + if self._scope != current_scope: msg = f"Cannot enter context for resource with scope {self._scope} in scope {current_scope!r}" raise InvalidContextError(msg) context_item: ResourceContext[T_co] = ResourceContext(is_async=self._is_async) @@ -620,7 +620,7 @@ def _add_providers_from_containers( for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): provider_scope = container_provider.get_scope() - if provider_scope is scope or provider_scope is ContextScopes.ANY: + if provider_scope in (scope, ContextScopes.ANY): target.add(container_provider) @override