Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,11 @@ def _establish_hooks_and_provider(

client_metadata = self.get_metadata()
provider_metadata = provider.get_metadata()
provider_hooks = (
[]
if self._provider_uses_internal_hooks(provider)
else provider.get_provider_hooks()
)

# Hooks need to be handled in different orders at different stages
# in the flag evaluation
Expand All @@ -450,7 +455,7 @@ def _establish_hooks_and_provider(
get_hooks(),
self.hooks,
evaluation_hooks,
provider.get_provider_hooks(),
provider_hooks,
)
]
# after, error, finally: Provider, Invocation, Client, API
Expand All @@ -465,6 +470,36 @@ def _establish_hooks_and_provider(
merged_eval_context,
)

def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool:
uses_internal_hooks = getattr(provider, "uses_internal_provider_hooks", None)
return bool(callable(uses_internal_hooks) and uses_internal_hooks())

def _set_internal_provider_hook_runtime(
self,
provider: FeatureProvider,
flag_type: FlagType,
hook_hints: HookHints,
) -> object | None:
if not self._provider_uses_internal_hooks(provider):
return None
set_hook_runtime = getattr(provider, "set_internal_provider_hook_runtime", None)
if not callable(set_hook_runtime):
return None
return set_hook_runtime(
flag_type=flag_type,
client_metadata=self.get_metadata(),
hook_hints=hook_hints,
)

def _reset_internal_provider_hook_runtime(
self, provider: FeatureProvider, runtime_token: object | None
) -> None:
if runtime_token is None:
return
reset_hook_runtime = getattr(provider, "reset_internal_provider_hook_runtime", None)
if callable(reset_hook_runtime):
reset_hook_runtime(runtime_token)

def _assert_provider_status(
self,
) -> OpenFeatureError | None:
Expand Down Expand Up @@ -611,13 +646,21 @@ async def evaluate_flag_details_async(
merged_eval_context,
)

flag_evaluation = await self._create_provider_evaluation_async(
runtime_token = self._set_internal_provider_hook_runtime(
provider,
flag_type,
flag_key,
default_value,
merged_context,
hook_hints,
)
try:
flag_evaluation = await self._create_provider_evaluation_async(
provider,
flag_type,
flag_key,
default_value,
merged_context,
)
finally:
self._reset_internal_provider_hook_runtime(provider, runtime_token)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, err, reversed_merged_hooks_and_context, hook_hints
Expand Down Expand Up @@ -787,13 +830,21 @@ def evaluate_flag_details(
merged_eval_context,
)

flag_evaluation = self._create_provider_evaluation(
runtime_token = self._set_internal_provider_hook_runtime(
provider,
flag_type,
flag_key,
default_value,
merged_context,
hook_hints,
)
try:
flag_evaluation = self._create_provider_evaluation(
provider,
flag_type,
flag_key,
default_value,
merged_context,
)
finally:
self._reset_internal_provider_hook_runtime(provider, runtime_token)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, err, reversed_merged_hooks_and_context, hook_hints
Expand Down
23 changes: 22 additions & 1 deletion openfeature/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@
if typing.TYPE_CHECKING:
from openfeature.flag_evaluation import FlagValueType

__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"]
__all__ = [
"AbstractProvider",
"ComparisonStrategy",
"EvaluationStrategy",
"FeatureProvider",
"FirstMatchStrategy",
"FirstSuccessfulStrategy",
"Metadata",
"MultiProvider",
"ProviderEntry",
"ProviderStatus",
]


class ProviderStatus(Enum):
Expand Down Expand Up @@ -247,3 +258,13 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None:
def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None:
if hasattr(self, "_on_emit"):
self._on_emit(self, event, details)


from .multi_provider import ( # noqa: E402
ComparisonStrategy,
EvaluationStrategy,
FirstMatchStrategy,
FirstSuccessfulStrategy,
MultiProvider,
ProviderEntry,
)
29 changes: 18 additions & 11 deletions openfeature/provider/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,25 @@ def _initialize_provider(self, provider: FeatureProvider) -> None:
try:
if hasattr(provider, "initialize"):
provider.initialize(self._get_evaluation_context())
self.dispatch_event(
provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails()
)
if self.get_provider_status(provider) == ProviderStatus.NOT_READY:
self.dispatch_event(
provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails()
)
except Exception as err:
error_code = (
err.error_code
if isinstance(err, OpenFeatureError)
else ErrorCode.GENERAL
)
self.dispatch_event(
provider,
ProviderEvent.PROVIDER_ERROR,
ProviderEventDetails(
message=f"Provider initialization failed: {err}",
error_code=error_code,
),
)
if self.get_provider_status(provider) == ProviderStatus.NOT_READY:
self.dispatch_event(
provider,
ProviderEvent.PROVIDER_ERROR,
ProviderEventDetails(
message=f"Provider initialization failed: {err}",
error_code=error_code,
),
)

def _shutdown_provider(self, provider: FeatureProvider) -> None:
try:
Expand All @@ -115,6 +117,11 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None:
provider.detach()

def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus:
provider_status_getter = getattr(provider, "get_status", None)
if callable(provider_status_getter):
status = provider_status_getter()
if isinstance(status, ProviderStatus):
return status
return self._provider_status.get(provider, ProviderStatus.NOT_READY)

def dispatch_event(
Expand Down
Loading
Loading