diff --git a/openfeature/_event_support.py b/openfeature/_event_support.py index 3928be3e..41fc28f0 100644 --- a/openfeature/_event_support.py +++ b/openfeature/_event_support.py @@ -61,6 +61,13 @@ def add_client_handler( handlers = _client_handlers[client][event] handlers.append(handler) + # outside the lock intentionally: the immediate-fire status check acquires the registry lock, so calling it + # under _client_lock risks lock-order inversion against run_handlers_for_provider (registry lock → _client_lock). + # As a consequence, a narrow double-fire is possible: if dispatch_event(client's event) runs concurrently, it + # sets the matching provider status (enabling the immediate fire below) and then re-runs every handler for this + # client. If _run_immediate_handler lands after that status set but before dispatch snapshots the handler list, + # the handler fires twice — once here, once from dispatch. Only happens when the registered event matches the event + # being dispatched; otherwise the immediate fire is a no-op. _run_immediate_handler(client, event, handler) @@ -78,6 +85,7 @@ def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None: from openfeature.api import get_client # noqa: PLC0415 + # See comment in add_client_handler for why this runs outside the lock. _run_immediate_handler(get_client(), event, handler) @@ -134,6 +142,6 @@ def _run_handler(handler: EventHandler, details: EventDetails) -> None: def clear() -> None: with _global_lock: - _global_handlers.clear() - with _client_lock: - _client_handlers.clear() + with _client_lock: + _global_handlers.clear() + _client_handlers.clear() diff --git a/openfeature/api.py b/openfeature/api.py index 4585e50e..80fa2434 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -62,7 +62,6 @@ def set_provider_and_wait(provider: FeatureProvider, domain: str | None = None) def clear_providers() -> None: provider_registry.clear_providers() - _event_support.clear() def get_provider_metadata(domain: str | None = None) -> Metadata: diff --git a/openfeature/client.py b/openfeature/client.py index 95dc5b6d..9bf7f513 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -1,4 +1,5 @@ import logging +import threading import typing from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass @@ -86,6 +87,7 @@ def __init__( self.version = version self.context = context or EvaluationContext() self.hooks = hooks or [] + self._hooks_lock = threading.Lock() @property def provider(self) -> FeatureProvider: @@ -98,7 +100,10 @@ def get_metadata(self) -> ClientMetadata: return ClientMetadata(domain=self.domain) def add_hooks(self, hooks: list[Hook]) -> None: - self.hooks = self.hooks + hooks + # Guards the read-concat-store against a lost update; this practically never races under the default 5ms GIL + # switch interval, but is essential under a no-GIL build. + with self._hooks_lock: + self.hooks = self.hooks + hooks def get_boolean_value( self, @@ -468,8 +473,9 @@ def _establish_hooks_and_provider( def _assert_provider_status( self, + provider: FeatureProvider, ) -> OpenFeatureError | None: - status = self.get_provider_status() + status = provider_registry.get_provider_status(provider) if status == ProviderStatus.NOT_READY: return ProviderNotReadyError() if status == ProviderStatus.FATAL: @@ -589,7 +595,7 @@ async def evaluate_flag_details_async( ) try: - if provider_err := self._assert_provider_status(): + if provider_err := self._assert_provider_status(provider): error_hooks( flag_type, provider_err, @@ -765,7 +771,7 @@ def evaluate_flag_details( ) try: - if provider_err := self._assert_provider_status(): + if provider_err := self._assert_provider_status(provider): error_hooks( flag_type, provider_err, diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 247d316b..b0f142cc 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading import typing from collections.abc import Mapping, MutableMapping, Sequence from datetime import datetime @@ -24,6 +25,7 @@ ] _hooks: list[Hook] = [] +_hooks_lock = threading.Lock() # https://openfeature.dev/specification/sections/hooks/#requirement-461 @@ -150,15 +152,20 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool: """ return True +# while the lock guarantees safety, even without it there was never a loss within 50.000 runs (with the default GIL +# switch interval of 5ms). only when the switch interval was significantly shortened to 0.1 microseconds, losses were +# observed without locks every now and then. with a no-GIL python, the lock would be essential def add_hooks(hooks: list[Hook]) -> None: - global _hooks - _hooks = _hooks + hooks + with _hooks_lock: + global _hooks + _hooks = _hooks + hooks def clear_hooks() -> None: - global _hooks - _hooks = [] + with _hooks_lock: + global _hooks + _hooks = [] def get_hooks() -> list[Hook]: diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 1b2b5206..e02aab78 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -261,5 +261,6 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_STALE, details) def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: - if hasattr(self, "_on_emit"): - self._on_emit(self, event, details) + on_emit = getattr(self, "_on_emit", None) + if on_emit is not None: + on_emit(self, event, details) diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index e46caadd..63fa6ce3 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -1,6 +1,6 @@ import threading -from openfeature._event_support import run_handlers_for_provider +from openfeature._event_support import run_handlers_for_provider, clear as clear_event_handlers from openfeature.evaluation_context import EvaluationContext, get_evaluation_context from openfeature.event import ( ProviderEvent, @@ -54,9 +54,16 @@ def set_provider( self._shutdown_if_unused(old_provider) def get_provider(self, domain: str | None) -> FeatureProvider: - if domain is None: - return self._default_provider - return self._providers.get(domain, self._default_provider) + # defensive lock under the GIL as the op is basically atomic + # but we might want to keep it so a provider that's about + # to be shut down isn't returned + # however it contributes to a potential deadlock that is currently + # still in place (clear_providers: registry's lock -> _event_support's lock; + # run_handlers_for_provider: _event_support's lock -> registry's lock) + with self._lock: + if domain is None: + return self._default_provider + return self._providers.get(domain, self._default_provider) def set_default_provider( self, provider: FeatureProvider, wait_for_init: bool = False @@ -83,7 +90,8 @@ def set_default_provider( self._shutdown_if_unused(old_provider) def get_default_provider(self) -> FeatureProvider: - return self._default_provider + with self._lock: + return self._default_provider def clear_providers(self) -> None: self.shutdown() @@ -93,11 +101,13 @@ def clear_providers(self) -> None: self._provider_status = { self._default_provider: ProviderStatus.READY, } + clear_event_handlers() def shutdown(self) -> None: with self._lock: providers = {self._default_provider, *self._providers.values()} + # do we want to move this inside the lock? it allows a narrow double-shutdown window for provider in providers: self._shutdown_provider(provider) @@ -214,7 +224,12 @@ def _shutdown_provider( provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: - return self._provider_status.get(provider, ProviderStatus.NOT_READY) + # defensive lock under the GIL as the op is basically atomic + # but we might want to keep it so a provider that's about + # to be shut down isn't returned + # however, removing it would enable moving _run_immediate_handler into the lock i think + with self._lock: + return self._provider_status.get(provider, ProviderStatus.NOT_READY) def dispatch_event( self, diff --git a/openfeature/transaction_context/__init__.py b/openfeature/transaction_context/__init__.py index 15ac7e01..e44f6135 100644 --- a/openfeature/transaction_context/__init__.py +++ b/openfeature/transaction_context/__init__.py @@ -1,3 +1,5 @@ +import threading + from openfeature.evaluation_context import EvaluationContext from openfeature.transaction_context.context_var_transaction_context_propagator import ( ContextVarsTransactionContextPropagator, @@ -21,13 +23,15 @@ _evaluation_transaction_context_propagator: TransactionContextPropagator = ( NoOpTransactionContextPropagator() ) +_propagator_lock = threading.Lock() def set_transaction_context_propagator( transaction_context_propagator: TransactionContextPropagator, ) -> None: global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator = transaction_context_propagator + with _propagator_lock: + _evaluation_transaction_context_propagator = transaction_context_propagator def clear_transaction_context_propagator() -> None: @@ -35,11 +39,12 @@ def clear_transaction_context_propagator() -> None: def get_transaction_context() -> EvaluationContext: - return _evaluation_transaction_context_propagator.get_transaction_context() + with _propagator_lock: + propagator = _evaluation_transaction_context_propagator + return propagator.get_transaction_context() def set_transaction_context(evaluation_context: EvaluationContext) -> None: - global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator.set_transaction_context( - evaluation_context - ) + with _propagator_lock: + propagator = _evaluation_transaction_context_propagator + propagator.set_transaction_context(evaluation_context) diff --git a/tests/test_client.py b/tests/test_client.py index 44b49e5f..64598936 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,11 +4,12 @@ import types import uuid from concurrent.futures import ThreadPoolExecutor -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest from openfeature import _event_support, api +from openfeature import client as client_module from openfeature.api import ( add_hooks, clear_hooks, @@ -20,7 +21,7 @@ from openfeature.client import OpenFeatureClient, _typecheck_flag_value from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails -from openfeature.exception import ErrorCode, OpenFeatureError +from openfeature.exception import ErrorCode, OpenFeatureError, ProviderFatalError from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason from openfeature.hook import Hook from openfeature.provider import FeatureProvider, ProviderStatus @@ -291,9 +292,10 @@ def test_provider_should_return_error_status_if_failed(): async def test_should_shortcircuit_if_provider_is_not_ready( no_op_provider_client, monkeypatch ): - # Given monkeypatch.setattr( - no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY + provider_registry, + "get_provider_status", + lambda provider: ProviderStatus.NOT_READY, ) spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) @@ -321,9 +323,10 @@ async def test_should_shortcircuit_if_provider_is_not_ready( async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( no_op_provider_client, monkeypatch ): - # Given monkeypatch.setattr( - no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL + provider_registry, + "get_provider_status", + lambda provider: ProviderStatus.FATAL, ) spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) @@ -768,3 +771,32 @@ def test_should_noop_if_provider_does_not_support_tracking(monkeypatch): set_provider(provider) client = get_client() client.track(tracking_event_name="test") + + +def test_assert_provider_status_uses_passed_provider_not_current_registry_state(): + fatal_provider = NoOpProvider() + ready_provider = NoOpProvider() + + registry_mock = Mock() + registry_mock.get_provider_status.side_effect = lambda p: ( + ProviderStatus.FATAL if p is fatal_provider else ProviderStatus.READY + ) + registry_mock.get_provider.return_value = ready_provider + + original = client_module.provider_registry + client_module.provider_registry = registry_mock + try: + c = OpenFeatureClient(domain=None, version=None) + assert c.provider is ready_provider, ( + "test setup: self.provider should resolve via the patched registry" + ) + + err = c._assert_provider_status(fatal_provider) + assert isinstance(err, ProviderFatalError), ( + "status check used self.provider (READY) instead of the captured " + "fatal_provider — TOCTOU regression" + ) + registry_mock.get_provider_status.assert_any_call(fatal_provider) + assert c._assert_provider_status(ready_provider) is None + finally: + client_module.provider_registry = original