diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index 2b56d67..aa1f903 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -20,7 +20,7 @@ class ContextProvider(AbstractProvider[types.T_co]): ``ArgumentResolutionError``. """ - __slots__ = ("_context_type",) + __slots__ = ("context_type",) def __init__( self, @@ -32,10 +32,12 @@ def __init__( super().__init__( scope=scope, bound_type=context_type if isinstance(bound_type, types.UnsetType) else bound_type ) - self._context_type = context_type + # Public, like its sibling ``bound_type`` — the type this provider supplies and + # the key its value is set under in ``context``. + self.context_type = context_type def __repr__(self) -> str: - return f"ContextProvider(context_type={self._context_type!r}, scope={self.scope!r})" + return f"ContextProvider(context_type={self.context_type!r}, scope={self.scope!r})" def resolve(self, container: "Container") -> types.T_co | None: value = self.fetch_context_value(container) @@ -47,4 +49,4 @@ def fetch_context_value(self, container: "Container") -> types.T_co | object: container = container.find_container(self.scope) if container.closed: raise exceptions.ContainerClosedError(container_scope=container.scope) - return container.context_registry.find_context(self._context_type) + return container.context_registry.find_context(self.context_type) diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index a17c96a..01d5b05 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -66,6 +66,11 @@ def test_context_provider_repr() -> None: assert repr(provider) == "ContextProvider(context_type=, scope=)" +def test_context_provider_exposes_context_type() -> None: + provider = providers.ContextProvider(context_type=datetime.datetime, scope=Scope.REQUEST) + assert provider.context_type is datetime.datetime + + @pytest.mark.parametrize("value", [0, False, "", [], {}, 0.0]) def test_context_provider_returns_falsy_values(value: object) -> None: context_type = type(value)