diff --git a/mvnw b/mvnw old mode 100644 new mode 100755 diff --git a/src/main/java/dev/openfeature/sdk/EventProvider.java b/src/main/java/dev/openfeature/sdk/EventProvider.java index c126c1451..a02191606 100644 --- a/src/main/java/dev/openfeature/sdk/EventProvider.java +++ b/src/main/java/dev/openfeature/sdk/EventProvider.java @@ -2,9 +2,12 @@ import dev.openfeature.sdk.internal.ConfigurableThreadFactory; import dev.openfeature.sdk.internal.TriConsumer; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import lombok.extern.slf4j.Slf4j; /** @@ -22,6 +25,7 @@ @Slf4j public abstract class EventProvider implements FeatureProvider { private EventProviderListener eventProviderListener; + private final List> eventObservers = new CopyOnWriteArrayList<>(); private final ExecutorService emitterExecutor = Executors.newCachedThreadPool(new ConfigurableThreadFactory("openfeature-event-emitter-thread", true)); @@ -54,6 +58,31 @@ void detach() { this.onEmit = null; } + /** + * Add a provider event observer. + * + *

Observers are invoked whenever this provider emits an event and are intended for advanced + * provider composition scenarios. + * + * @param observer observer callback + */ + public void addEventObserver(BiConsumer observer) { + if (observer != null) { + eventObservers.add(observer); + } + } + + /** + * Remove a previously registered provider event observer. + * + * @param observer observer callback + */ + public void removeEventObserver(BiConsumer observer) { + if (observer != null) { + eventObservers.remove(observer); + } + } + /** * Stop the event emitter executor and block until either termination has completed * or timeout period has elapsed. @@ -81,8 +110,9 @@ public void shutdown() { public Awaitable emit(final ProviderEvent event, final ProviderEventDetails details) { final var localEventProviderListener = this.eventProviderListener; final var localOnEmit = this.onEmit; + final var localEventObservers = this.eventObservers; - if (localEventProviderListener == null && localOnEmit == null) { + if (localEventProviderListener == null && localOnEmit == null && localEventObservers.isEmpty()) { return Awaitable.FINISHED; } @@ -98,6 +128,13 @@ public Awaitable emit(final ProviderEvent event, final ProviderEventDetails deta if (localOnEmit != null) { localOnEmit.accept(this, event, details); } + for (BiConsumer observer : localEventObservers) { + try { + observer.accept(event, details); + } catch (Exception e) { + log.error("Exception in provider event observer {}", observer, e); + } + } } finally { awaitable.wakeup(); } diff --git a/src/main/java/dev/openfeature/sdk/multiprovider/ComparisonStrategy.java b/src/main/java/dev/openfeature/sdk/multiprovider/ComparisonStrategy.java new file mode 100644 index 000000000..d2ba63aca --- /dev/null +++ b/src/main/java/dev/openfeature/sdk/multiprovider/ComparisonStrategy.java @@ -0,0 +1,164 @@ +package dev.openfeature.sdk.multiprovider; + +import dev.openfeature.sdk.ErrorCode; +import dev.openfeature.sdk.EvaluationContext; +import dev.openfeature.sdk.FeatureProvider; +import dev.openfeature.sdk.ProviderEvaluation; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.function.BiConsumer; +import java.util.function.Function; +import lombok.Getter; + +/** + * Comparison strategy. + * + *

Evaluates all providers and compares successful results. + */ +public class ComparisonStrategy implements Strategy { + + @Getter + private final String fallbackProvider; + + private final BiConsumer>> onMismatch; + + /** + * Constructs a comparison strategy with a fallback provider. + * + * @param fallbackProvider provider name to use as fallback when successful providers disagree + */ + public ComparisonStrategy(String fallbackProvider) { + this(fallbackProvider, null); + } + + /** + * Constructs a comparison strategy with fallback provider and mismatch callback. + * + * @param fallbackProvider provider name to use as fallback when successful providers disagree + * @param onMismatch callback invoked with all successful evaluations when they disagree + */ + public ComparisonStrategy( + String fallbackProvider, + BiConsumer>> onMismatch) { + this.fallbackProvider = Objects.requireNonNull(fallbackProvider, "fallbackProvider must not be null"); + this.onMismatch = onMismatch; + } + + @Override + public ProviderEvaluation evaluate( + Map providers, + String key, + T defaultValue, + EvaluationContext ctx, + Function> providerFunction) { + if (providers.isEmpty()) { + return ProviderEvaluation.builder() + .errorCode(ErrorCode.GENERAL) + .errorMessage("No providers configured") + .build(); + } + if (!providers.containsKey(fallbackProvider)) { + throw new IllegalArgumentException("fallbackProvider not found in providers: " + fallbackProvider); + } + + Map> successfulResults = new ConcurrentHashMap<>(providers.size()); + Map providerErrors = new ConcurrentHashMap<>(providers.size()); + ExecutorService executorService = Executors.newFixedThreadPool(providers.size()); + try { + List> tasks = new ArrayList<>(providers.size()); + for (Map.Entry entry : providers.entrySet()) { + String providerName = entry.getKey(); + FeatureProvider provider = entry.getValue(); + tasks.add(() -> { + try { + ProviderEvaluation evaluation = providerFunction.apply(provider); + if (evaluation == null) { + providerErrors.put(providerName, "null evaluation"); + } else if (evaluation.getErrorCode() == null) { + successfulResults.put(providerName, evaluation); + } else { + providerErrors.put( + providerName, + evaluation.getErrorCode() + ": " + String.valueOf(evaluation.getErrorMessage())); + } + } catch (Exception e) { + providerErrors.put(providerName, e.getClass().getSimpleName() + ": " + e.getMessage()); + } + return null; + }); + } + List> futures = executorService.invokeAll(tasks); + for (Future future : futures) { + future.get(); + } + } catch (Exception e) { + return ProviderEvaluation.builder() + .errorCode(ErrorCode.GENERAL) + .errorMessage("Comparison strategy failed: " + e.getMessage()) + .build(); + } finally { + executorService.shutdown(); + } + + if (!providerErrors.isEmpty()) { + return ProviderEvaluation.builder() + .errorCode(ErrorCode.GENERAL) + .errorMessage("Provider errors: " + buildErrorSummary(providerErrors)) + .build(); + } + + ProviderEvaluation fallbackResult = successfulResults.get(fallbackProvider); + if (fallbackResult == null) { + return ProviderEvaluation.builder() + .errorCode(ErrorCode.GENERAL) + .errorMessage("Fallback provider did not return a successful evaluation: " + fallbackProvider) + .build(); + } + + if (allEvaluationsMatch(successfulResults)) { + return fallbackResult; + } + + if (onMismatch != null) { + Map> mismatchPayload = new LinkedHashMap<>(successfulResults); + onMismatch.accept(key, Collections.unmodifiableMap(mismatchPayload)); + } + return fallbackResult; + } + + private String buildErrorSummary(Map providerErrors) { + StringBuilder builder = new StringBuilder(); + boolean first = true; + for (Map.Entry entry : providerErrors.entrySet()) { + if (!first) { + builder.append("; "); + } + first = false; + builder.append(entry.getKey()).append(" -> ").append(entry.getValue()); + } + return builder.toString(); + } + + private boolean allEvaluationsMatch(Map> results) { + ProviderEvaluation baseline = null; + for (ProviderEvaluation evaluation : results.values()) { + if (baseline == null) { + baseline = evaluation; + continue; + } + if (!Objects.equals(baseline.getValue(), evaluation.getValue())) { + return false; + } + } + return true; + } +} diff --git a/src/main/java/dev/openfeature/sdk/multiprovider/MultiProvider.java b/src/main/java/dev/openfeature/sdk/multiprovider/MultiProvider.java index cc6fb8db2..937a10458 100644 --- a/src/main/java/dev/openfeature/sdk/multiprovider/MultiProvider.java +++ b/src/main/java/dev/openfeature/sdk/multiprovider/MultiProvider.java @@ -1,11 +1,26 @@ package dev.openfeature.sdk.multiprovider; +import dev.openfeature.sdk.ClientMetadata; +import dev.openfeature.sdk.DefaultHookData; +import dev.openfeature.sdk.ErrorCode; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.EventProvider; import dev.openfeature.sdk.FeatureProvider; +import dev.openfeature.sdk.FlagEvaluationDetails; +import dev.openfeature.sdk.FlagValueType; +import dev.openfeature.sdk.Hook; +import dev.openfeature.sdk.HookContext; +import dev.openfeature.sdk.HookData; +import dev.openfeature.sdk.ImmutableContext; import dev.openfeature.sdk.Metadata; import dev.openfeature.sdk.ProviderEvaluation; +import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; +import dev.openfeature.sdk.ProviderState; +import dev.openfeature.sdk.TrackingEventDetails; import dev.openfeature.sdk.Value; +import dev.openfeature.sdk.exceptions.ExceptionUtils; +import dev.openfeature.sdk.exceptions.OpenFeatureError; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.ArrayList; import java.util.Collection; @@ -16,9 +31,12 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import lombok.Getter; import lombok.extern.slf4j.Slf4j; @@ -40,6 +58,12 @@ public class MultiProvider extends EventProvider { private final Map providers; private final Strategy strategy; + private final Map providerStates = new ConcurrentHashMap<>(); + private final Map> providerEventObservers = + new ConcurrentHashMap<>(); + private final ClientMetadata hookClientMetadata = MultiProvider::getNAME; + private final Map emptyHookHints = Collections.emptyMap(); + private ProviderState aggregateState; private MultiProviderMetadata metadata; /** @@ -61,20 +85,61 @@ public MultiProvider(List providers) { public MultiProvider(List providers, Strategy strategy) { this.providers = buildProviders(providers); this.strategy = Objects.requireNonNull(strategy, "strategy must not be null"); + initializeProviderStates(); + this.aggregateState = determineAggregateState(); } protected static Map buildProviders(List providers) { + Objects.requireNonNull(providers, "providers must not be null"); Map providersMap = new LinkedHashMap<>(providers.size()); + Map suffixesByBaseName = new HashMap<>(providers.size()); for (FeatureProvider provider : providers) { - FeatureProvider prevProvider = - providersMap.put(provider.getMetadata().getName(), provider); - if (prevProvider != null) { - log.info("duplicated provider name: {}", provider.getMetadata().getName()); + Objects.requireNonNull(provider, "provider must not be null"); + String baseName = getProviderBaseName(provider); + String resolvedName = resolveUniqueProviderName(baseName, providersMap, suffixesByBaseName); + if (!baseName.equals(resolvedName)) { + log.info("deduplicated provider name from {} to {}", baseName, resolvedName); } + providersMap.put(resolvedName, provider); } return Collections.unmodifiableMap(providersMap); } + private static String getProviderBaseName(FeatureProvider provider) { + Metadata providerMetadata = provider.getMetadata(); + if (providerMetadata == null || providerMetadata.getName() == null || providerMetadata.getName().isEmpty()) { + return "provider"; + } + return providerMetadata.getName(); + } + + private static String resolveUniqueProviderName( + String baseName, + Map providersMap, + Map suffixesByBaseName) { + if (!providersMap.containsKey(baseName)) { + suffixesByBaseName.putIfAbsent(baseName, 1); + return baseName; + } + int suffix = suffixesByBaseName.getOrDefault(baseName, 1); + String resolvedName = baseName + "-" + suffix; + while (providersMap.containsKey(resolvedName)) { + suffix++; + resolvedName = baseName + "-" + suffix; + } + suffixesByBaseName.put(baseName, suffix + 1); + return resolvedName; + } + + private void initializeProviderStates() { + providerStates.clear(); + if (!providers.isEmpty()) { + for (String providerName : providers.keySet()) { + providerStates.put(providerName, ProviderState.NOT_READY); + } + } + } + /** * Initialize the provider. * @@ -85,7 +150,11 @@ protected static Map buildProviders(List providersMetadata = new HashMap<>(); + Map providersMetadata = new LinkedHashMap<>(); + initializeProviderStates(); + synchronized (this) { + emitAggregateStateChange(determineAggregateState(), ProviderEventDetails.builder().build()); + } if (providers.isEmpty()) { metadataBuilder.originalMetadata(Collections.unmodifiableMap(providersMetadata)); @@ -96,13 +165,22 @@ public void initialize(EvaluationContext evaluationContext) throws Exception { ExecutorService executorService = Executors.newFixedThreadPool(Math.min(INIT_THREADS_COUNT, providers.size())); try { Collection> tasks = new ArrayList<>(providers.size()); - for (FeatureProvider provider : providers.values()) { + for (Map.Entry entry : providers.entrySet()) { + String providerName = entry.getKey(); + FeatureProvider provider = entry.getValue(); + registerChildProviderObserver(providerName, provider); tasks.add(() -> { - provider.initialize(evaluationContext); - return null; + try { + provider.initialize(evaluationContext); + setProviderState(providerName, ProviderState.READY, ProviderEventDetails.builder().build()); + return null; + } catch (Exception e) { + setProviderState(providerName, toStateFromException(e), providerErrorDetails(e)); + throw e; + } }); Metadata providerMetadata = provider.getMetadata(); - providersMetadata.put(providerMetadata.getName(), providerMetadata); + providersMetadata.put(providerName, providerMetadata); } metadataBuilder.originalMetadata(Collections.unmodifiableMap(providersMetadata)); @@ -138,42 +216,470 @@ public Metadata getMetadata() { @Override public ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { return strategy.evaluate( - providers, key, defaultValue, ctx, p -> p.getBooleanEvaluation(key, defaultValue, ctx)); + providers, + key, + defaultValue, + ctx, + provider -> evaluateWithProviderHooks( + provider, + key, + defaultValue, + ctx, + FlagValueType.BOOLEAN, + (p, evaluationContext) -> p.getBooleanEvaluation(key, defaultValue, evaluationContext))); } @Override public ProviderEvaluation getStringEvaluation(String key, String defaultValue, EvaluationContext ctx) { - return strategy.evaluate(providers, key, defaultValue, ctx, p -> p.getStringEvaluation(key, defaultValue, ctx)); + return strategy.evaluate( + providers, + key, + defaultValue, + ctx, + provider -> evaluateWithProviderHooks( + provider, + key, + defaultValue, + ctx, + FlagValueType.STRING, + (p, evaluationContext) -> p.getStringEvaluation(key, defaultValue, evaluationContext))); } @Override public ProviderEvaluation getIntegerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { return strategy.evaluate( - providers, key, defaultValue, ctx, p -> p.getIntegerEvaluation(key, defaultValue, ctx)); + providers, + key, + defaultValue, + ctx, + provider -> evaluateWithProviderHooks( + provider, + key, + defaultValue, + ctx, + FlagValueType.INTEGER, + (p, evaluationContext) -> p.getIntegerEvaluation(key, defaultValue, evaluationContext))); } @Override public ProviderEvaluation getDoubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { - return strategy.evaluate(providers, key, defaultValue, ctx, p -> p.getDoubleEvaluation(key, defaultValue, ctx)); + return strategy.evaluate( + providers, + key, + defaultValue, + ctx, + provider -> evaluateWithProviderHooks( + provider, + key, + defaultValue, + ctx, + FlagValueType.DOUBLE, + (p, evaluationContext) -> p.getDoubleEvaluation(key, defaultValue, evaluationContext))); } @Override public ProviderEvaluation getObjectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { - return strategy.evaluate(providers, key, defaultValue, ctx, p -> p.getObjectEvaluation(key, defaultValue, ctx)); + return strategy.evaluate( + providers, + key, + defaultValue, + ctx, + provider -> evaluateWithProviderHooks( + provider, + key, + defaultValue, + ctx, + FlagValueType.OBJECT, + (p, evaluationContext) -> p.getObjectEvaluation(key, defaultValue, evaluationContext))); + } + + @Override + public void track(String eventName, EvaluationContext context, TrackingEventDetails details) { + for (Map.Entry entry : providers.entrySet()) { + String providerName = entry.getKey(); + FeatureProvider provider = entry.getValue(); + if (!shouldTrackProvider(providerName)) { + continue; + } + try { + provider.track(eventName, context, details); + } catch (Exception e) { + log.error("error forwarding track to provider {}", providerName, e); + } + } } @Override public void shutdown() { log.debug("shutdown begin"); - for (FeatureProvider provider : providers.values()) { + for (Map.Entry entry : providers.entrySet()) { + String providerName = entry.getKey(); + FeatureProvider provider = entry.getValue(); try { + unregisterChildProviderObserver(providerName, provider); provider.shutdown(); } catch (Exception e) { - log.error("error shutdown provider {}", provider.getMetadata().getName(), e); + log.error("error shutdown provider {}", providerName, e); } } + synchronized (this) { + initializeProviderStates(); + emitAggregateStateChange(ProviderState.NOT_READY, ProviderEventDetails.builder().build()); + } log.debug("shutdown end"); - // Important: ensure EventProvider's executor is also shut down super.shutdown(); } + + private void registerChildProviderObserver(String providerName, FeatureProvider provider) { + if (provider instanceof EventProvider) { + BiConsumer observer = + (event, details) -> onChildProviderEvent(providerName, event, details); + ((EventProvider) provider).addEventObserver(observer); + providerEventObservers.put(providerName, observer); + } + } + + private void unregisterChildProviderObserver(String providerName, FeatureProvider provider) { + if (provider instanceof EventProvider) { + BiConsumer observer = providerEventObservers.remove(providerName); + if (observer != null) { + ((EventProvider) provider).removeEventObserver(observer); + } + } + } + + private void onChildProviderEvent(String providerName, ProviderEvent event, ProviderEventDetails details) { + if (ProviderEvent.PROVIDER_CONFIGURATION_CHANGED.equals(event)) { + emitProviderConfigurationChanged(details); + return; + } + ProviderState state = toStateFromEvent(event, details); + if (state != null) { + setProviderState(providerName, state, details); + } + } + + private synchronized void setProviderState( + String providerName, + ProviderState providerState, + ProviderEventDetails details) { + providerStates.put(providerName, providerState); + ProviderState aggregate = determineAggregateState(); + emitAggregateStateChange(aggregate, details); + } + + private void emitAggregateStateChange(ProviderState aggregate, ProviderEventDetails details) { + ProviderState previous = aggregateState; + if (previous == aggregate) { + return; + } + aggregateState = aggregate; + switch (aggregate) { + case READY: + emitProviderReady(detailsOrEmpty(details)); + break; + case STALE: + emitProviderStale(detailsOrEmpty(details)); + break; + case ERROR: + emitProviderError(ensureErrorDetails(details, ErrorCode.GENERAL)); + break; + case FATAL: + emitProviderError(ensureErrorDetails(details, ErrorCode.PROVIDER_FATAL)); + break; + case NOT_READY: + break; + default: + break; + } + } + + private ProviderState determineAggregateState() { + if (providerStates.isEmpty()) { + return ProviderState.READY; + } + ProviderState aggregate = ProviderState.READY; + for (ProviderState state : providerStates.values()) { + if (stateSeverity(state) > stateSeverity(aggregate)) { + aggregate = state; + } + } + return aggregate; + } + + private int stateSeverity(ProviderState state) { + if (state == null) { + return 0; + } + switch (state) { + case FATAL: + return 5; + case NOT_READY: + return 4; + case ERROR: + return 3; + case STALE: + return 2; + case READY: + return 1; + default: + return 0; + } + } + + private ProviderEventDetails detailsOrEmpty(ProviderEventDetails details) { + if (details == null) { + return ProviderEventDetails.builder().build(); + } + return details; + } + + private ProviderEventDetails ensureErrorDetails(ProviderEventDetails details, ErrorCode defaultErrorCode) { + if (details == null) { + return ProviderEventDetails.builder().errorCode(defaultErrorCode).build(); + } + if (details.getErrorCode() == null) { + return details.toBuilder().errorCode(defaultErrorCode).build(); + } + return details; + } + + private ProviderState toStateFromEvent(ProviderEvent event, ProviderEventDetails details) { + if (ProviderEvent.PROVIDER_READY.equals(event)) { + return ProviderState.READY; + } + if (ProviderEvent.PROVIDER_STALE.equals(event)) { + return ProviderState.STALE; + } + if (ProviderEvent.PROVIDER_ERROR.equals(event)) { + if (details != null && ErrorCode.PROVIDER_FATAL.equals(details.getErrorCode())) { + return ProviderState.FATAL; + } + return ProviderState.ERROR; + } + return null; + } + + private ProviderState toStateFromException(Exception exception) { + if (exception instanceof OpenFeatureError + && ErrorCode.PROVIDER_FATAL.equals(((OpenFeatureError) exception).getErrorCode())) { + return ProviderState.FATAL; + } + return ProviderState.ERROR; + } + + private ProviderEventDetails providerErrorDetails(Exception exception) { + if (exception instanceof OpenFeatureError) { + ErrorCode errorCode = ((OpenFeatureError) exception).getErrorCode(); + return ProviderEventDetails.builder() + .errorCode(errorCode) + .message(exception.getMessage()) + .build(); + } + return ProviderEventDetails.builder() + .errorCode(ErrorCode.GENERAL) + .message(exception.getMessage()) + .build(); + } + + private boolean shouldTrackProvider(String providerName) { + ProviderState providerState = providerStates.getOrDefault(providerName, ProviderState.READY); + return !ProviderState.NOT_READY.equals(providerState) && !ProviderState.FATAL.equals(providerState); + } + + private EvaluationContext copyEvaluationContext(EvaluationContext context) { + if (context == null) { + return ImmutableContext.EMPTY; + } + String targetingKey = context.getTargetingKey(); + if (targetingKey == null) { + return new ImmutableContext(context.asMap()); + } + return new ImmutableContext(targetingKey, context.asMap()); + } + + private EvaluationContext toProviderContext(EvaluationContext originalContext, EvaluationContext evaluatedContext) { + if (originalContext == null && (evaluatedContext == null || evaluatedContext.isEmpty())) { + return null; + } + return evaluatedContext; + } + + private Exception toEvaluationException(ProviderEvaluation providerEvaluation) { + if (providerEvaluation == null || providerEvaluation.getErrorCode() == null) { + return new RuntimeException("Provider evaluation returned an error"); + } + return ExceptionUtils.instantiateErrorByErrorCode( + providerEvaluation.getErrorCode(), + providerEvaluation.getErrorMessage()); + } + + private HookContext createHookContext( + String key, + FlagValueType valueType, + T defaultValue, + EvaluationContext evaluationContext, + FeatureProvider provider, + HookData hookData) { + return HookContext.builder() + .flagKey(key) + .type(valueType) + .defaultValue(normalizeDefaultValue(valueType, defaultValue)) + .ctx(evaluationContext) + .clientMetadata(hookClientMetadata) + .providerMetadata(provider.getMetadata()) + .hookData(hookData) + .build(); + } + + @SuppressWarnings("unchecked") + private T normalizeDefaultValue(FlagValueType valueType, T defaultValue) { + if (defaultValue != null) { + return defaultValue; + } + switch (valueType) { + case BOOLEAN: + return (T) Boolean.FALSE; + case STRING: + return (T) ""; + case INTEGER: + return (T) Integer.valueOf(0); + case DOUBLE: + return (T) Double.valueOf(0d); + case OBJECT: + return (T) new Value(); + default: + return defaultValue; + } + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private ProviderEvaluation evaluateWithProviderHooks( + FeatureProvider provider, + String key, + T defaultValue, + EvaluationContext ctx, + FlagValueType valueType, + BiFunction> providerFunction) { + List providerHooks = provider.getProviderHooks(); + if (providerHooks == null || providerHooks.isEmpty()) { + return providerFunction.apply(provider, ctx); + } + + List> hooks = new ArrayList<>(providerHooks.size()); + for (Hook hook : providerHooks) { + if (hook.supportsFlagValueType(valueType)) { + hooks.add(new HookExecution<>(hook, new DefaultHookData())); + } + } + + if (hooks.isEmpty()) { + return providerFunction.apply(provider, ctx); + } + + EvaluationContext evaluatedContext = copyEvaluationContext(ctx); + ProviderEvaluation providerEvaluation = null; + FlagEvaluationDetails details = null; + + try { + for (int i = hooks.size() - 1; i >= 0; i--) { + HookExecution execution = hooks.get(i); + HookContext hookContext = + createHookContext(key, valueType, defaultValue, evaluatedContext, provider, execution.hookData); + var contextUpdate = execution.hook.before(hookContext, emptyHookHints); + if (contextUpdate != null + && contextUpdate.isPresent() + && contextUpdate.get() != hookContext.getCtx() + && !contextUpdate.get().isEmpty()) { + evaluatedContext = evaluatedContext.merge(contextUpdate.get()); + } + } + + providerEvaluation = providerFunction.apply(provider, toProviderContext(ctx, evaluatedContext)); + details = FlagEvaluationDetails.from(providerEvaluation, key); + + if (providerEvaluation.getErrorCode() == null) { + for (HookExecution execution : hooks) { + execution.hook.after( + createHookContext( + key, + valueType, + defaultValue, + evaluatedContext, + provider, + execution.hookData), + details, + emptyHookHints); + } + } else { + Exception providerException = toEvaluationException(providerEvaluation); + for (HookExecution execution : hooks) { + try { + execution.hook.error( + createHookContext( + key, + valueType, + defaultValue, + evaluatedContext, + provider, + execution.hookData), + providerException, + emptyHookHints); + } catch (Exception e) { + log.error("error executing provider hook error stage", e); + } + } + } + + return providerEvaluation; + } catch (Exception e) { + for (HookExecution execution : hooks) { + try { + execution.hook.error( + createHookContext( + key, + valueType, + defaultValue, + evaluatedContext, + provider, + execution.hookData), + e, + emptyHookHints); + } catch (Exception hookError) { + log.error("error executing provider hook error stage", hookError); + } + } + throw e; + } finally { + FlagEvaluationDetails finalDetails = details == null + ? FlagEvaluationDetails.builder().flagKey(key).value(defaultValue).build() + : details; + for (HookExecution execution : hooks) { + try { + execution.hook.finallyAfter( + createHookContext( + key, + valueType, + defaultValue, + evaluatedContext, + provider, + execution.hookData), + finalDetails, + emptyHookHints); + } catch (Exception e) { + log.error("error executing provider hook finally stage", e); + } + } + } + } + + private static final class HookExecution { + private final Hook hook; + private final HookData hookData; + + private HookExecution(Hook hook, HookData hookData) { + this.hook = hook; + this.hookData = hookData; + } + } } diff --git a/src/test/java/dev/openfeature/sdk/multiprovider/ComparisonStrategyTest.java b/src/test/java/dev/openfeature/sdk/multiprovider/ComparisonStrategyTest.java new file mode 100644 index 000000000..7938e2b95 --- /dev/null +++ b/src/test/java/dev/openfeature/sdk/multiprovider/ComparisonStrategyTest.java @@ -0,0 +1,103 @@ +package dev.openfeature.sdk.multiprovider; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.openfeature.sdk.ErrorCode; +import dev.openfeature.sdk.FeatureProvider; +import dev.openfeature.sdk.ProviderEvaluation; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class ComparisonStrategyTest extends BaseStrategyTest { + + @Test + void shouldReturnFallbackResultWhenAllProvidersAgree() { + setupProviderSuccess(mockProvider1, "same"); + setupProviderSuccess(mockProvider2, "same"); + + Map providers = new LinkedHashMap<>(); + providers.put("provider1", mockProvider1); + providers.put("provider2", mockProvider2); + + ComparisonStrategy strategy = new ComparisonStrategy("provider2"); + ProviderEvaluation result = strategy.evaluate( + providers, + FLAG_KEY, + DEFAULT_STRING, + null, + p -> p.getStringEvaluation(FLAG_KEY, DEFAULT_STRING, null)); + + assertNotNull(result); + assertEquals("same", result.getValue()); + assertNull(result.getErrorCode()); + } + + @Test + void shouldCallMismatchCallbackAndReturnFallbackResult() { + setupProviderSuccess(mockProvider1, "first"); + setupProviderSuccess(mockProvider2, "second"); + + Map providers = new LinkedHashMap<>(); + providers.put("provider1", mockProvider1); + providers.put("provider2", mockProvider2); + + AtomicInteger callbackCount = new AtomicInteger(); + ComparisonStrategy strategy = new ComparisonStrategy("provider2", (key, evaluations) -> callbackCount.incrementAndGet()); + + ProviderEvaluation result = strategy.evaluate( + providers, + FLAG_KEY, + DEFAULT_STRING, + null, + p -> p.getStringEvaluation(FLAG_KEY, DEFAULT_STRING, null)); + + assertEquals("second", result.getValue()); + assertNull(result.getErrorCode()); + assertEquals(1, callbackCount.get()); + } + + @Test + void shouldReturnGeneralErrorWhenAnyProviderFails() { + setupProviderSuccess(mockProvider1, "ok"); + setupProviderError(mockProvider2, ErrorCode.PARSE_ERROR); + + Map providers = new LinkedHashMap<>(); + providers.put("provider1", mockProvider1); + providers.put("provider2", mockProvider2); + + ComparisonStrategy strategy = new ComparisonStrategy("provider1"); + ProviderEvaluation result = strategy.evaluate( + providers, + FLAG_KEY, + DEFAULT_STRING, + null, + p -> p.getStringEvaluation(FLAG_KEY, DEFAULT_STRING, null)); + + assertEquals(ErrorCode.GENERAL, result.getErrorCode()); + assertTrue(result.getErrorMessage().contains("provider2")); + } + + @Test + void shouldThrowWhenFallbackProviderIsMissing() { + setupProviderSuccess(mockProvider1, "ok"); + + Map providers = new LinkedHashMap<>(); + providers.put("provider1", mockProvider1); + + ComparisonStrategy strategy = new ComparisonStrategy("provider2"); + assertThrows( + IllegalArgumentException.class, + () -> strategy.evaluate( + providers, + FLAG_KEY, + DEFAULT_STRING, + null, + p -> p.getStringEvaluation(FLAG_KEY, DEFAULT_STRING, null))); + } +} diff --git a/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderEventsAndTrackingTest.java b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderEventsAndTrackingTest.java new file mode 100644 index 000000000..42a3f5027 --- /dev/null +++ b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderEventsAndTrackingTest.java @@ -0,0 +1,145 @@ +package dev.openfeature.sdk.multiprovider; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import dev.openfeature.sdk.Client; +import dev.openfeature.sdk.EvaluationContext; +import dev.openfeature.sdk.EventProvider; +import dev.openfeature.sdk.Metadata; +import dev.openfeature.sdk.OpenFeatureAPI; +import dev.openfeature.sdk.ProviderEvaluation; +import dev.openfeature.sdk.ProviderEventDetails; +import dev.openfeature.sdk.ProviderState; +import dev.openfeature.sdk.TrackingEventDetails; +import dev.openfeature.sdk.Value; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class MultiProviderEventsAndTrackingTest { + + @Test + void shouldAggregateChildProviderStateAndForwardConfigurationEvents() throws Exception { + TrackingProvider provider1 = new TrackingProvider("provider1"); + TrackingProvider provider2 = new TrackingProvider("provider2"); + MultiProvider multiProvider = new MultiProvider(List.of(provider1, provider2)); + + OpenFeatureAPI api = new TestOpenFeatureAPI(); + api.shutdown(); + try { + api.setProviderAndWait("multiProviderEvents", multiProvider); + Client client = api.getClient("multiProviderEvents"); + + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.READY); + + AtomicInteger configurationChangedCount = new AtomicInteger(); + client.onProviderConfigurationChanged(details -> configurationChangedCount.incrementAndGet()); + + provider1.emitProviderConfigurationChanged(ProviderEventDetails.builder().message("changed").build()).await(); + await().atMost(Duration.ofSeconds(2)).until(() -> configurationChangedCount.get() == 1); + + provider1.emitProviderStale(ProviderEventDetails.builder().message("stale").build()).await(); + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.STALE); + + provider2.emitProviderError( + ProviderEventDetails.builder().errorCode(dev.openfeature.sdk.ErrorCode.GENERAL).build()) + .await(); + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.ERROR); + + provider2.emitProviderReady(ProviderEventDetails.builder().build()).await(); + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.STALE); + + provider1.emitProviderReady(ProviderEventDetails.builder().build()).await(); + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.READY); + + provider1.emitProviderError( + ProviderEventDetails.builder() + .errorCode(dev.openfeature.sdk.ErrorCode.PROVIDER_FATAL) + .build()) + .await(); + await().atMost(Duration.ofSeconds(2)).until(() -> client.getProviderState() == ProviderState.FATAL); + } finally { + api.shutdown(); + } + } + + @Test + void shouldForwardTrackToReadyProvidersAndSkipFatalProviders() throws Exception { + TrackingProvider provider1 = new TrackingProvider("provider1"); + TrackingProvider provider2 = new TrackingProvider("provider2"); + provider2.throwOnTrack = true; + + MultiProvider multiProvider = new MultiProvider(List.of(provider1, provider2)); + multiProvider.initialize(null); + + multiProvider.track("event1", null, null); + assertEquals(1, provider1.trackCount.get()); + assertEquals(1, provider2.trackCount.get()); + + provider1.emitProviderError( + ProviderEventDetails.builder() + .errorCode(dev.openfeature.sdk.ErrorCode.PROVIDER_FATAL) + .build()) + .await(); + + multiProvider.track("event2", null, null); + assertEquals(1, provider1.trackCount.get()); + assertEquals(2, provider2.trackCount.get()); + } + + static class TrackingProvider extends EventProvider { + private final String name; + private final AtomicInteger trackCount = new AtomicInteger(); + private boolean throwOnTrack; + + TrackingProvider(String name) { + this.name = name; + } + + @Override + public Metadata getMetadata() { + return () -> name; + } + + @Override + public void track(String eventName, EvaluationContext context, TrackingEventDetails details) { + trackCount.incrementAndGet(); + if (throwOnTrack) { + throw new RuntimeException("track failure"); + } + } + + @Override + public ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(Boolean.TRUE).build(); + } + + @Override + public ProviderEvaluation getStringEvaluation(String key, String defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value("value").build(); + } + + @Override + public ProviderEvaluation getIntegerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(1).build(); + } + + @Override + public ProviderEvaluation getDoubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(1d).build(); + } + + @Override + public ProviderEvaluation getObjectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(new Value("value")).build(); + } + } + + static class TestOpenFeatureAPI extends OpenFeatureAPI { + TestOpenFeatureAPI() { + super(); + } + } +} diff --git a/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderHooksTest.java b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderHooksTest.java new file mode 100644 index 000000000..14a5001eb --- /dev/null +++ b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderHooksTest.java @@ -0,0 +1,157 @@ +package dev.openfeature.sdk.multiprovider; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import dev.openfeature.sdk.EvaluationContext; +import dev.openfeature.sdk.EventProvider; +import dev.openfeature.sdk.Hook; +import dev.openfeature.sdk.HookContext; +import dev.openfeature.sdk.Metadata; +import dev.openfeature.sdk.MutableContext; +import dev.openfeature.sdk.ProviderEvaluation; +import dev.openfeature.sdk.Value; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class MultiProviderHooksTest { + + @Test + void shouldExecuteProviderHooksAndKeepPerProviderContextIsolation() throws Exception { + RecordingHook firstHook = new RecordingHook("provider1"); + RecordingHook secondHook = new RecordingHook("provider2"); + + HookedStringProvider provider1 = new HookedStringProvider( + "provider1", + List.of(firstHook), + ProviderEvaluation.builder() + .errorCode(dev.openfeature.sdk.ErrorCode.GENERAL) + .errorMessage("failed") + .build()); + HookedStringProvider provider2 = new HookedStringProvider( + "provider2", + List.of(secondHook), + ProviderEvaluation.builder().value("ok").build()); + + MultiProvider multiProvider = new MultiProvider(List.of(provider1, provider2), new FirstSuccessfulStrategy()); + multiProvider.initialize(null); + + ProviderEvaluation evaluation = multiProvider.getStringEvaluation("flag", "default", null); + + assertEquals("ok", evaluation.getValue()); + + assertEquals(1, firstHook.beforeCount.get()); + assertEquals(0, firstHook.afterCount.get()); + assertEquals(1, firstHook.errorCount.get()); + assertEquals(1, firstHook.finallyCount.get()); + + assertEquals(1, secondHook.beforeCount.get()); + assertEquals(1, secondHook.afterCount.get()); + assertEquals(0, secondHook.errorCount.get()); + assertEquals(1, secondHook.finallyCount.get()); + + assertEquals("provider1", provider1.lastEvaluationContext.getValue("hookOwner").asString()); + assertNull(provider1.lastEvaluationContext.getValue("provider2Marker")); + + assertEquals("provider2", provider2.lastEvaluationContext.getValue("hookOwner").asString()); + assertNull(provider2.lastEvaluationContext.getValue("provider1Marker")); + } + + static class RecordingHook implements Hook { + private final String providerName; + private final AtomicInteger beforeCount = new AtomicInteger(); + private final AtomicInteger afterCount = new AtomicInteger(); + private final AtomicInteger errorCount = new AtomicInteger(); + private final AtomicInteger finallyCount = new AtomicInteger(); + + RecordingHook(String providerName) { + this.providerName = providerName; + } + + @Override + public Optional before(HookContext ctx, Map hints) { + beforeCount.incrementAndGet(); + ctx.getHookData().set("provider", providerName); + return Optional.of(new MutableContext() + .add("hookOwner", providerName) + .add(providerName + "Marker", providerName)); + } + + @Override + public void after( + HookContext ctx, + dev.openfeature.sdk.FlagEvaluationDetails details, + Map hints) { + afterCount.incrementAndGet(); + assertEquals(providerName, ctx.getHookData().get("provider")); + } + + @Override + public void error(HookContext ctx, Exception error, Map hints) { + errorCount.incrementAndGet(); + assertEquals(providerName, ctx.getHookData().get("provider")); + } + + @Override + public void finallyAfter( + HookContext ctx, + dev.openfeature.sdk.FlagEvaluationDetails details, + Map hints) { + finallyCount.incrementAndGet(); + assertEquals(providerName, ctx.getHookData().get("provider")); + } + } + + static class HookedStringProvider extends EventProvider { + private final String name; + private final List hooks; + private final ProviderEvaluation evaluation; + private EvaluationContext lastEvaluationContext; + + HookedStringProvider(String name, List hooks, ProviderEvaluation evaluation) { + this.name = name; + this.hooks = hooks; + this.evaluation = evaluation; + } + + @Override + public Metadata getMetadata() { + return () -> name; + } + + @Override + public List getProviderHooks() { + return hooks; + } + + @Override + public ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(defaultValue).build(); + } + + @Override + public ProviderEvaluation getStringEvaluation(String key, String defaultValue, EvaluationContext ctx) { + lastEvaluationContext = ctx == null ? new MutableContext() : ctx; + return evaluation; + } + + @Override + public ProviderEvaluation getIntegerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(defaultValue).build(); + } + + @Override + public ProviderEvaluation getDoubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(defaultValue).build(); + } + + @Override + public ProviderEvaluation getObjectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { + return ProviderEvaluation.builder().value(defaultValue).build(); + } + } +} diff --git a/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderTest.java b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderTest.java index 887b71d0a..76c436cdb 100644 --- a/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderTest.java +++ b/src/test/java/dev/openfeature/sdk/multiprovider/MultiProviderTest.java @@ -73,7 +73,12 @@ void shouldHandleDuplicateProviderNames() { List providers = new ArrayList<>(2); providers.add(mockProvider1); providers.add(mockProvider2); - assertDoesNotThrow(() -> new MultiProvider(providers).initialize(null)); + MultiProvider multiProvider = new MultiProvider(providers); + assertDoesNotThrow(() -> multiProvider.initialize(null)); + MultiProviderMetadata metadata = (MultiProviderMetadata) multiProvider.getMetadata(); + assertEquals(2, metadata.getOriginalMetadata().size()); + assertNotNull(metadata.getOriginalMetadata().get("provider")); + assertNotNull(metadata.getOriginalMetadata().get("provider-1")); } @Test