diff --git a/openfeature/client.py b/openfeature/client.py index a02693c1..d01ee56b 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -429,6 +429,11 @@ def _establish_hooks_and_provider( client_metadata = self.get_metadata() provider_metadata = provider.get_metadata() + provider_hooks = ( + [] + if self._provider_uses_internal_hooks(provider) + else provider.get_provider_hooks() + ) # Hooks need to be handled in different orders at different stages # in the flag evaluation @@ -450,7 +455,7 @@ def _establish_hooks_and_provider( get_hooks(), self.hooks, evaluation_hooks, - provider.get_provider_hooks(), + provider_hooks, ) ] # after, error, finally: Provider, Invocation, Client, API @@ -465,6 +470,36 @@ def _establish_hooks_and_provider( merged_eval_context, ) + def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool: + uses_internal_hooks = getattr(provider, "uses_internal_provider_hooks", None) + return bool(callable(uses_internal_hooks) and uses_internal_hooks()) + + def _set_internal_provider_hook_runtime( + self, + provider: FeatureProvider, + flag_type: FlagType, + hook_hints: HookHints, + ) -> object | None: + if not self._provider_uses_internal_hooks(provider): + return None + set_hook_runtime = getattr(provider, "set_internal_provider_hook_runtime", None) + if not callable(set_hook_runtime): + return None + return set_hook_runtime( + flag_type=flag_type, + client_metadata=self.get_metadata(), + hook_hints=hook_hints, + ) + + def _reset_internal_provider_hook_runtime( + self, provider: FeatureProvider, runtime_token: object | None + ) -> None: + if runtime_token is None: + return + reset_hook_runtime = getattr(provider, "reset_internal_provider_hook_runtime", None) + if callable(reset_hook_runtime): + reset_hook_runtime(runtime_token) + def _assert_provider_status( self, ) -> OpenFeatureError | None: @@ -611,13 +646,21 @@ async def evaluate_flag_details_async( merged_eval_context, ) - flag_evaluation = await self._create_provider_evaluation_async( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints @@ -787,13 +830,21 @@ def evaluate_flag_details( merged_eval_context, ) - flag_evaluation = self._create_provider_evaluation( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = self._create_provider_evaluation( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index aea5069f..b022bbc7 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -15,7 +15,18 @@ if typing.TYPE_CHECKING: from openfeature.flag_evaluation import FlagValueType -__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"] +__all__ = [ + "AbstractProvider", + "ComparisonStrategy", + "EvaluationStrategy", + "FeatureProvider", + "FirstMatchStrategy", + "FirstSuccessfulStrategy", + "Metadata", + "MultiProvider", + "ProviderEntry", + "ProviderStatus", +] class ProviderStatus(Enum): @@ -247,3 +258,13 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): self._on_emit(self, event, details) + + +from .multi_provider import ( # noqa: E402 + ComparisonStrategy, + EvaluationStrategy, + FirstMatchStrategy, + FirstSuccessfulStrategy, + MultiProvider, + ProviderEntry, +) diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..1944dd9c 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -80,23 +80,25 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) - self.dispatch_event( - provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() - ) + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() + ) except Exception as err: error_code = ( err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) - self.dispatch_event( - provider, - ProviderEvent.PROVIDER_ERROR, - ProviderEventDetails( - message=f"Provider initialization failed: {err}", - error_code=error_code, - ), - ) + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails( + message=f"Provider initialization failed: {err}", + error_code=error_code, + ), + ) def _shutdown_provider(self, provider: FeatureProvider) -> None: try: @@ -115,6 +117,11 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: + provider_status_getter = getattr(provider, "get_status", None) + if callable(provider_status_getter): + status = provider_status_getter() + if isinstance(status, ProviderStatus): + return status return self._provider_status.get(provider, ProviderStatus.NOT_READY) def dispatch_event( diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py new file mode 100644 index 00000000..daff1cdf --- /dev/null +++ b/openfeature/provider/multi_provider.py @@ -0,0 +1,1192 @@ +from __future__ import annotations + +import asyncio +import contextvars +import logging +import threading +import typing +from collections.abc import Awaitable, Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + FlagType, + FlagValueType, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.hook._hook_support import ( + after_all_hooks, + after_hooks, + before_hooks, + error_hooks, +) +from openfeature.provider import ( + AbstractProvider, + FeatureProvider, + Metadata, + ProviderStatus, +) + +__all__ = [ + "ComparisonStrategy", + "EvaluationStrategy", + "FirstMatchStrategy", + "FirstSuccessfulStrategy", + "MultiProvider", + "ProviderEntry", +] + +logger = logging.getLogger("openfeature") + +T = typing.TypeVar("T", bound=FlagValueType) +RunMode: typing.TypeAlias = typing.Literal["sequential", "parallel"] +ComparisonMismatchHandler: typing.TypeAlias = Callable[ + [str, Mapping[str, FlagResolutionDetails[FlagValueType]]], None +] + + +@dataclass(frozen=True) +class ProviderEntry: + provider: FeatureProvider + name: str | None = None + + +@dataclass(frozen=True) +class _ProviderEvaluation(typing.Generic[T]): + provider_name: str + provider: FeatureProvider + result: FlagResolutionDetails[T] + + +@dataclass(frozen=True) +class _ProviderHookRuntime: + flag_type: FlagType + client_metadata: typing.Any + hook_hints: HookHints + + +class EvaluationStrategy(typing.Protocol): + run_mode: RunMode + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: ... + + +def _is_success(result: FlagResolutionDetails[FlagValueType]) -> bool: + return result.error_code is None and result.reason != Reason.ERROR + + +def _validate_run_mode(run_mode: RunMode) -> RunMode: + if run_mode not in ("sequential", "parallel"): + raise ValueError(f"Unsupported run_mode '{run_mode}'") + return run_mode + + +def _format_result_error( + provider_name: str, result: FlagResolutionDetails[FlagValueType] +) -> str: + error_code = result.error_code.value if result.error_code else ErrorCode.GENERAL.value + error_message = result.error_message or "Unknown error" + return f"{provider_name}: {error_code} ({error_message})" + + +def _build_aggregated_error( + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + prefix: str, +) -> FlagResolutionDetails[FlagValueType]: + if not evaluations: + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': no providers returned a result", + ) + + errors_text = "; ".join( + _format_result_error(evaluation.provider_name, evaluation.result) + for evaluation in evaluations + ) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': {errors_text}", + ) + + +class FirstMatchStrategy: + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return result.error_code == ErrorCode.FLAG_NOT_FOUND + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if self.should_use_result( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if not self.should_continue( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if evaluations: + return evaluations[-1].result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "Multi-provider evaluation failed", + ) + + +class FirstSuccessfulStrategy: + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if _is_success(evaluation.result): + return evaluation.result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "All providers failed", + ) + + +class ComparisonStrategy: + run_mode: RunMode = "parallel" + + def __init__( + self, + fallback_provider: str | None = None, + on_mismatch: ComparisonMismatchHandler | None = None, + ) -> None: + self.fallback_provider = fallback_provider + self.on_mismatch = on_mismatch + + def validate_provider_names(self, provider_names: Sequence[str]) -> None: + if ( + self.fallback_provider is not None + and self.fallback_provider not in provider_names + ): + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return False + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + failed_evaluations = [ + evaluation for evaluation in evaluations if not _is_success(evaluation.result) + ] + if failed_evaluations: + return _build_aggregated_error( + flag_key, + default_value, + failed_evaluations, + "Comparison strategy received provider errors", + ) + + fallback_evaluation = self._select_fallback_evaluation(evaluations) + fallback_value = fallback_evaluation.result.value + has_mismatch = any( + evaluation.result.value != fallback_value for evaluation in evaluations + ) + if has_mismatch and self.on_mismatch is not None: + mismatch_results = { + evaluation.provider_name: evaluation.result for evaluation in evaluations + } + try: + self.on_mismatch(flag_key, mismatch_results) + except Exception: + logger.exception( + "Comparison strategy mismatch callback failed for flag '%s'", + flag_key, + ) + return fallback_evaluation.result + + def _select_fallback_evaluation( + self, evaluations: list[_ProviderEvaluation[FlagValueType]] + ) -> _ProviderEvaluation[FlagValueType]: + if not evaluations: + raise ValueError("ComparisonStrategy requires at least one provider") + if self.fallback_provider is None: + return evaluations[0] + for evaluation in evaluations: + if evaluation.provider_name == self.fallback_provider: + return evaluation + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) + + +class MultiProvider(AbstractProvider): + _status_precedence: tuple[ProviderStatus, ...] = ( + ProviderStatus.FATAL, + ProviderStatus.NOT_READY, + ProviderStatus.ERROR, + ProviderStatus.STALE, + ProviderStatus.READY, + ) + + def __init__( + self, + providers: list[ProviderEntry], + strategy: EvaluationStrategy | None = None, + ) -> None: + super().__init__() + if not providers: + raise ValueError("At least one provider must be provided") + + self.strategy = strategy or FirstMatchStrategy() + self._registeredProviders: list[tuple[str, FeatureProvider]] = [] + self._provider_names: dict[FeatureProvider, str] = {} + self._provider_statuses: dict[str, ProviderStatus] = {} + self._aggregate_status = ProviderStatus.NOT_READY + self._statusLock = threading.Lock() + self._hookRuntime: contextvars.ContextVar[_ProviderHookRuntime | None] = ( + contextvars.ContextVar( + f"multiProviderHookRuntime:{id(self)}", + default=None, + ) + ) + self._register_providers(providers) + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registeredProviders + } + validate_provider_names = getattr(self.strategy, "validate_provider_names", None) + if callable(validate_provider_names): + validate_provider_names( + [provider_name for provider_name, _ in self._registeredProviders] + ) + + def uses_internal_provider_hooks(self) -> bool: + return True + + def set_internal_provider_hook_runtime( + self, + flag_type: FlagType, + client_metadata: typing.Any, + hook_hints: HookHints, + ) -> contextvars.Token[_ProviderHookRuntime | None]: + return self._hookRuntime.set( + _ProviderHookRuntime( + flag_type=flag_type, + client_metadata=client_metadata, + hook_hints=hook_hints, + ) + ) + + def reset_internal_provider_hook_runtime( + self, token: contextvars.Token[_ProviderHookRuntime | None] + ) -> None: + self._hookRuntime.reset(token) + + def get_status(self) -> ProviderStatus: + with self._statusLock: + return self._aggregate_status + + def _register_providers(self, providers: list[ProviderEntry]) -> None: + name_counts: dict[str, int] = {} + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + name_counts[metadata_name] = name_counts.get(metadata_name, 0) + 1 + + used_names: set[str] = set() + name_indexes: dict[str, int] = {} + + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + if entry.name: + if entry.name in used_names: + raise ValueError(f"Provider name '{entry.name}' is not unique") + provider_name = entry.name + elif name_counts[metadata_name] == 1 and metadata_name not in used_names: + provider_name = metadata_name + else: + while True: + name_indexes[metadata_name] = name_indexes.get(metadata_name, 0) + 1 + provider_name = f"{metadata_name}_{name_indexes[metadata_name]}" + if provider_name not in used_names: + break + + used_names.add(provider_name) + self._registeredProviders.append((provider_name, entry.provider)) + self._provider_names[entry.provider] = provider_name + + def get_metadata(self) -> Metadata: + return Metadata(name="MultiProvider") + + def get_provider_hooks(self) -> list[Hook]: + return [] + + def attach( + self, + on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None], + ) -> None: + super().attach(on_emit) + for _, provider in self._registeredProviders: + provider.attach(self._handle_provider_event) + + def detach(self) -> None: + for _, provider in self._registeredProviders: + provider.detach() + super().detach() + + def initialize(self, evaluation_context: EvaluationContext) -> None: + def initialize_provider( + entry: tuple[str, FeatureProvider], + ) -> tuple[str, Exception | None]: + provider_name, provider = entry + try: + provider.initialize(evaluation_context) + return provider_name, None + except Exception as err: + return provider_name, err + + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + init_results = list(executor.map(initialize_provider, self._registeredProviders)) + + error_messages: list[str] = [] + event_details = ProviderEventDetails() + for provider_name, err in init_results: + if err is None: + self._mark_provider_ready(provider_name) + continue + provider_status = self._status_from_exception(err) + self._set_provider_status(provider_name, provider_status) + error_messages.append( + f"Provider '{provider_name}' initialization failed: {self._error_message_from_exception(err)}" + ) + event_details = self._details_from_exception(err, provider_name) + + self._refresh_aggregate_status(event_details) + + if error_messages: + raise GeneralError(f"Multi-provider initialization failed: {'; '.join(error_messages)}") + + def shutdown(self) -> None: + for _, provider in self._registeredProviders: + provider.detach() + + def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: + provider_name, provider = entry + try: + provider.shutdown() + except Exception: + logger.exception("Provider '%s' shutdown failed", provider_name) + + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + list(executor.map(shutdown_provider, self._registeredProviders)) + + with self._statusLock: + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registeredProviders + } + self._aggregate_status = ProviderStatus.NOT_READY + + def _handle_provider_event( + self, + provider: FeatureProvider, + event: ProviderEvent, + details: ProviderEventDetails, + ) -> None: + provider_name = self._provider_names.get(provider) + if provider_name is None: + return + if event == ProviderEvent.PROVIDER_CONFIGURATION_CHANGED: + self.emit( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + self._with_provider_metadata(details, provider_name), + ) + return + if event == ProviderEvent.PROVIDER_READY: + self._set_provider_status(provider_name, ProviderStatus.READY) + elif event == ProviderEvent.PROVIDER_STALE: + self._set_provider_status(provider_name, ProviderStatus.STALE) + elif event == ProviderEvent.PROVIDER_ERROR: + self._set_provider_status( + provider_name, + self._status_from_event_details(details), + ) + self._refresh_aggregate_status(self._with_provider_metadata(details, provider_name)) + + def _set_provider_status( + self, provider_name: str, provider_status: ProviderStatus + ) -> None: + with self._statusLock: + self._provider_statuses[provider_name] = provider_status + + def _mark_provider_ready(self, provider_name: str) -> None: + with self._statusLock: + if self._provider_statuses.get(provider_name) == ProviderStatus.NOT_READY: + self._provider_statuses[provider_name] = ProviderStatus.READY + + def _calculate_aggregate_status(self) -> ProviderStatus: + statuses = tuple(self._provider_statuses.values()) + if not statuses: + return ProviderStatus.NOT_READY + for status in self._status_precedence: + if status in statuses: + return status + return ProviderStatus.NOT_READY + + def _refresh_aggregate_status(self, details: ProviderEventDetails) -> None: + event_to_emit: ProviderEvent | None = None + event_details = details + with self._statusLock: + previous_status = self._aggregate_status + aggregate_status = self._calculate_aggregate_status() + if previous_status == aggregate_status: + return + self._aggregate_status = aggregate_status + event_to_emit = self._event_from_status(aggregate_status) + event_details = self._details_for_status(aggregate_status, details) + if event_to_emit is not None: + self.emit(event_to_emit, event_details) + + def _event_from_status(self, provider_status: ProviderStatus) -> ProviderEvent | None: + if provider_status == ProviderStatus.READY: + return ProviderEvent.PROVIDER_READY + if provider_status == ProviderStatus.STALE: + return ProviderEvent.PROVIDER_STALE + if provider_status in (ProviderStatus.ERROR, ProviderStatus.FATAL): + return ProviderEvent.PROVIDER_ERROR + return None + + def _details_for_status( + self, provider_status: ProviderStatus, details: ProviderEventDetails + ) -> ProviderEventDetails: + error_code = details.error_code + if provider_status == ProviderStatus.FATAL: + error_code = ErrorCode.PROVIDER_FATAL + elif provider_status == ProviderStatus.ERROR and error_code is None: + error_code = ErrorCode.GENERAL + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=error_code, + metadata=dict(details.metadata), + ) + + def _with_provider_metadata( + self, details: ProviderEventDetails, provider_name: str + ) -> ProviderEventDetails: + metadata = dict(details.metadata) + metadata["provider_name"] = provider_name + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=details.error_code, + metadata=metadata, + ) + + def _status_from_event_details( + self, details: ProviderEventDetails + ) -> ProviderStatus: + if details.error_code == ErrorCode.PROVIDER_FATAL: + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _status_from_exception(self, err: Exception) -> ProviderStatus: + if ( + isinstance(err, OpenFeatureError) + and err.error_code == ErrorCode.PROVIDER_FATAL + ): + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _details_from_exception( + self, err: Exception, provider_name: str + ) -> ProviderEventDetails: + error_code = ( + err.error_code + if isinstance(err, OpenFeatureError) + else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return ProviderEventDetails( + message=f"Provider '{provider_name}' failed: {error_message}", + error_code=error_code, + metadata={"provider_name": provider_name}, + ) + + def _error_message_from_exception(self, err: Exception) -> str: + if isinstance(err, OpenFeatureError) and err.error_message: + return err.error_message + return str(err) + + def _resolution_from_exception( + self, default_value: T, err: Exception + ) -> FlagResolutionDetails[T]: + error_code = ( + err.error_code + if isinstance(err, OpenFeatureError) + else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=error_code, + error_message=error_message, + ) + + def _create_provider_hook_contexts( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: FlagValueType, + evaluation_context: EvaluationContext, + client_metadata: typing.Any, + ) -> list[tuple[Hook, HookContext]]: + provider_metadata = provider.get_metadata() + return [ + ( + hook, + HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=evaluation_context, + client_metadata=client_metadata, + provider_metadata=provider_metadata, + hook_data={}, + ), + ) + for hook in provider.get_provider_hooks() + ] + + def _evaluate_provider_sync( # noqa: PLR0913 + self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hookRuntime.get() + if runtime is None or not provider.get_provider_hooks(): + try: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolve_fn(provider, flag_key, default_value, evaluation_context), + ) + except Exception as err: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, + flag_key, + default_value, + provider_context, + runtime.client_metadata, + ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = resolve_fn(provider, flag_key, default_value, resolved_context) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + + async def _evaluate_provider_async( # noqa: PLR0913 + self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hookRuntime.get() + if runtime is None or not provider.get_provider_hooks(): + try: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=await resolve_fn( + provider, flag_key, default_value, evaluation_context + ), + ) + except Exception as err: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, + flag_key, + default_value, + provider_context, + runtime.client_metadata, + ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = await resolve_fn(provider, flag_key, default_value, resolved_context) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + + def _evaluate_with_providers( + self, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> FlagResolutionDetails[T]: + if self.strategy.run_mode == "parallel": + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + futures = [ + executor.submit( + self._evaluate_provider_sync, + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + for provider_name, provider in self._registeredProviders + ] + evaluations = [future.result() for future in futures] + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + evaluations: list[_ProviderEvaluation[T]] = [] + for provider_name, provider in self._registeredProviders: + evaluation = self._evaluate_provider_sync( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + async def _evaluate_with_providers_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> FlagResolutionDetails[T]: + if self.strategy.run_mode == "parallel": + tasks = [ + asyncio.create_task( + self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + ) + for provider_name, provider in self._registeredProviders + ] + evaluations = await asyncio.gather(*tasks) + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + list(evaluations), + ), + ), + ) + + evaluations: list[_ProviderEvaluation[T]] = [] + for provider_name, provider in self._registeredProviders: + evaluation = await self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return self._evaluate_with_providers( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return await self._evaluate_with_providers_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return self._evaluate_with_providers( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return await self._evaluate_with_providers_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return self._evaluate_with_providers( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return await self._evaluate_with_providers_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return self._evaluate_with_providers( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return await self._evaluate_with_providers_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_object_details( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return self._evaluate_with_providers( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return await self._evaluate_with_providers_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) diff --git a/tests/test_multi_provider.py b/tests/test_multi_provider.py new file mode 100644 index 00000000..aa6cea26 --- /dev/null +++ b/tests/test_multi_provider.py @@ -0,0 +1,607 @@ +import asyncio +import threading +from unittest.mock import MagicMock + +import pytest + +from openfeature import api +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.provider import ( + AbstractProvider, + ComparisonStrategy, + FirstMatchStrategy, + FirstSuccessfulStrategy, + Metadata, + MultiProvider, + ProviderEntry, + ProviderStatus, +) + + +class BooleanProvider(AbstractProvider): + def __init__( + self, + name: str, + boolean_result: FlagResolutionDetails[bool] | None = None, + boolean_exception: Exception | None = None, + hook_list: list[Hook] | None = None, + sync_blocker: "SyncBlocker | None" = None, + async_blocker: "AsyncBlocker | None" = None, + ) -> None: + super().__init__() + self.name = name + self.booleanResult = boolean_result + self.booleanException = boolean_exception + self.hookList = hook_list or [] + self.sync_blocker = sync_blocker + self.async_blocker = async_blocker + self.resolveCount = 0 + self.seenContexts: list[dict[str, object]] = [] + + def get_metadata(self) -> Metadata: + return Metadata(name=self.name) + + def get_provider_hooks(self) -> list[Hook]: + return self.hookList + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + if self.sync_blocker is not None: + self.sync_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + if self.async_blocker is not None: + await self.async_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_object_details( + self, + flag_key: str, + default_value: dict | list, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[dict | list]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + +class RecordingHook(Hook): + def __init__(self, hook_name: str) -> None: + self.hookName = hook_name + self.events: list[str] = [] + + def before( + self, hook_context: HookContext, hints: HookHints + ) -> EvaluationContext | None: + del hook_context + del hints + self.events.append("before") + return EvaluationContext(attributes={"hookOwner": self.hookName}) + + def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + del hook_context + del details + del hints + self.events.append("after") + + def error( + self, hook_context: HookContext, exception: Exception, hints: HookHints + ) -> None: + del hook_context + del exception + del hints + self.events.append("error") + + def finally_after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + del hook_context + del details + del hints + self.events.append("finally") + + +class SyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = threading.Event() + self.releaseEvent = threading.Event() + self.lock = threading.Lock() + + def wait(self) -> None: + with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + assert self.releaseEvent.wait(timeout=2) + + +class AsyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = asyncio.Event() + self.releaseEvent = asyncio.Event() + self.lock = asyncio.Lock() + + async def wait(self) -> None: + async with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + await asyncio.wait_for(self.releaseEvent.wait(), timeout=2) + + +def test_multi_provider_requires_at_least_one_provider(): + with pytest.raises(ValueError, match="At least one provider must be provided"): + MultiProvider([]) + + +def test_multi_provider_rejects_duplicate_explicit_names(): + first_provider = BooleanProvider("provider") + second_provider = BooleanProvider("provider") + + with pytest.raises(ValueError, match="Provider name 'duplicate' is not unique"): + MultiProvider( + [ + ProviderEntry(first_provider, name="duplicate"), + ProviderEntry(second_provider, name="duplicate"), + ] + ) + + +def test_comparison_strategy_rejects_unknown_fallback_provider(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + + with pytest.raises(ValueError, match="Fallback provider 'missing' is not registered"): + MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(fallback_provider="missing"), + ) + + +def test_first_match_uses_fallback_after_flag_not_found(): + missing_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ) + first_provider = BooleanProvider("first", boolean_result=missing_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_match_stops_on_non_flag_not_found_error(): + error_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message="boom", + ) + first_provider = BooleanProvider("first", boolean_result=error_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert second_provider.resolveCount == 0 + + +def test_first_successful_skips_general_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("broken")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_successful_aggregates_errors_when_all_providers_fail(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider("second", boolean_exception=GeneralError("second")) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + assert "second: GENERAL (second)" in result.error_message + + +def test_comparison_strategy_returns_fallback_value_and_calls_on_mismatch(): + mismatch_spy = MagicMock() + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy( + fallback_provider="second", + on_mismatch=mismatch_spy, + ), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + mismatch_spy.assert_called_once() + + +def test_comparison_strategy_aggregates_provider_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + + +def test_multi_provider_runs_sync_parallel_evaluation(): + sync_blocker = SyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + result_holder: list[FlagResolutionDetails[bool]] = [] + + def evaluate() -> None: + result_holder.append(multi_provider.resolve_boolean_details("flagKey", False)) + + worker_thread = threading.Thread(target=evaluate) + worker_thread.start() + + assert sync_blocker.enteredEvent.wait(timeout=2) + sync_blocker.releaseEvent.set() + worker_thread.join(timeout=2) + + assert result_holder[0].value is False + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +@pytest.mark.asyncio +async def test_multi_provider_runs_async_parallel_evaluation(): + async_blocker = AsyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + evaluation_task = asyncio.create_task( + multi_provider.resolve_boolean_details_async("flagKey", False) + ) + + await asyncio.wait_for(async_blocker.enteredEvent.wait(), timeout=2) + async_blocker.releaseEvent.set() + result = await asyncio.wait_for(evaluation_task, timeout=2) + + assert result.value is False + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_multi_provider_isolates_provider_hooks_and_runs_lifecycle(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) + client = api.get_client() + result = client.get_boolean_details( + "flagKey", + False, + evaluation_context=EvaluationContext(attributes={"base": "value"}), + ) + + assert result.value is True + assert first_hook.events == ["before", "error", "finally"] + assert second_hook.events == ["before", "after", "finally"] + assert first_provider.seenContexts[0]["base"] == "value" + assert first_provider.seenContexts[0]["hookOwner"] == "first" + assert second_provider.seenContexts[0]["base"] == "value" + assert second_provider.seenContexts[0]["hookOwner"] == "second" + + +def test_multi_provider_does_not_run_unused_provider_hooks(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) + client = api.get_client() + result = client.get_boolean_details("flagKey", False) + + assert result.value is True + assert first_hook.events == ["before", "after", "finally"] + assert second_hook.events == [] + + +def test_multi_provider_aggregates_status_and_deduplicates_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready) + client.add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) + client.add_handler(ProviderEvent.PROVIDER_STALE, spy.provider_stale) + spy.provider_ready.reset_mock() + + first_provider.emit_provider_stale(ProviderEventDetails(message="stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + second_provider.emit_provider_stale(ProviderEventDetails(message="still stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + first_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.GENERAL, message="error") + ) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 1 + + second_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.PROVIDER_FATAL, message="fatal") + ) + assert client.get_provider_status() == ProviderStatus.FATAL + assert spy.provider_error.call_count == 2 + + second_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 3 + + first_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.READY + assert spy.provider_ready.call_count == 1 + + +def test_multi_provider_forwards_configuration_changed_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + spy.provider_configuration_changed, + ) + + first_provider.emit_provider_configuration_changed(ProviderEventDetails(message="one")) + second_provider.emit_provider_configuration_changed(ProviderEventDetails(message="two")) + + assert spy.provider_configuration_changed.call_count == 2 + + +def test_multi_provider_reports_not_ready_after_shutdown(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + + api.shutdown() + + assert client.get_provider_status() == ProviderStatus.NOT_READY