diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java index 1a3e1d352..f88cfdd13 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java @@ -1,7 +1,10 @@ package dev.openfeature.contrib.providers.flagd; import dev.openfeature.contrib.providers.flagd.resolver.rpc.cache.CacheType; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; /** Helper class to hold configuration default values. */ @@ -37,6 +40,7 @@ public final class Config { static final String FLAGD_RETRY_BACKOFF_MAX_MS_VAR_NAME = "FLAGD_RETRY_BACKOFF_MAX_MS"; static final String STREAM_DEADLINE_MS_ENV_VAR_NAME = "FLAGD_STREAM_DEADLINE_MS"; static final String SOURCE_SELECTOR_ENV_VAR_NAME = "FLAGD_SOURCE_SELECTOR"; + static final String FATAL_STATUS_CODES_ENV_VAR_NAME = "FLAGD_FATAL_STATUS_CODES"; /** * Environment variable to fetch Provider id. * @@ -93,6 +97,18 @@ static long fallBackToEnvOrDefault(String key, long defaultValue) { } } + static List fallBackToEnvOrDefaultList(String key, List defaultValue) { + try { + return System.getenv(key) != null + ? Arrays.stream(System.getenv(key).split(",")) + .map(String::trim) + .collect(Collectors.toList()) + : defaultValue; + } catch (Exception e) { + return defaultValue; + } + } + static Resolver fromValueProvider(Function provider) { final String resolverVar = provider.apply(RESOLVER_ENV_VAR); if (resolverVar == null) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java index 0eebe16c9..5fffe6fa2 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java @@ -1,6 +1,7 @@ package dev.openfeature.contrib.providers.flagd; import static dev.openfeature.contrib.providers.flagd.Config.fallBackToEnvOrDefault; +import static dev.openfeature.contrib.providers.flagd.Config.fallBackToEnvOrDefaultList; import static dev.openfeature.contrib.providers.flagd.Config.fromValueProvider; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.QueueSource; @@ -122,6 +123,15 @@ public class FlagdOptions { @Builder.Default private int retryGracePeriod = fallBackToEnvOrDefault(Config.STREAM_RETRY_GRACE_PERIOD, Config.DEFAULT_STREAM_RETRY_GRACE_PERIOD); + + /** + * List of grpc response status codes for which the provider transitions into fatal state upon first connection. + * Defaults to empty list + */ + @Builder.Default + private List fatalStatusCodes = + fallBackToEnvOrDefaultList(Config.FATAL_STATUS_CODES_ENV_VAR_NAME, List.of()); + /** * Selector to be used with flag sync gRPC contract. * diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java index caf864175..4758e37c7 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java @@ -1,10 +1,10 @@ package dev.openfeature.contrib.providers.flagd; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.process.InProcessResolver; import dev.openfeature.contrib.providers.flagd.resolver.rpc.RpcResolver; import dev.openfeature.contrib.providers.flagd.resolver.rpc.cache.Cache; +import dev.openfeature.sdk.ErrorCode; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.EventProvider; import dev.openfeature.sdk.Hook; @@ -192,8 +192,9 @@ EvaluationContext getEnrichedContext() { } @SuppressWarnings("checkstyle:fallthrough") - private void onProviderEvent(FlagdProviderEvent flagdProviderEvent) { - log.debug("FlagdProviderEvent event {} ", flagdProviderEvent.getEvent()); + private void onProviderEvent( + ProviderEvent providerEvent, ProviderEventDetails providerEventDetails, Structure syncMetadata) { + log.debug("FlagdProviderEvent event {} ", providerEvent); synchronized (syncResources) { /* * We only use Error and Ready as previous states. @@ -204,10 +205,10 @@ private void onProviderEvent(FlagdProviderEvent flagdProviderEvent) { * forward a configuration changed to the ready, if we are not in the ready * state. */ - switch (flagdProviderEvent.getEvent()) { + switch (providerEvent) { case PROVIDER_CONFIGURATION_CHANGED: if (syncResources.getPreviousEvent() == ProviderEvent.PROVIDER_READY) { - onConfigurationChanged(flagdProviderEvent); + emit(providerEvent, providerEventDetails); break; } // intentional fall through @@ -216,33 +217,30 @@ private void onProviderEvent(FlagdProviderEvent flagdProviderEvent) { * Sync metadata is used to enrich the context, and is immutable in flagd, * so we only need it to be fetched once at READY. */ - if (flagdProviderEvent.getSyncMetadata() != null) { - syncResources.setEnrichedContext(contextEnricher.apply(flagdProviderEvent.getSyncMetadata())); + if (syncMetadata != null) { + syncResources.setEnrichedContext(contextEnricher.apply(syncMetadata)); } onReady(); syncResources.setPreviousEvent(ProviderEvent.PROVIDER_READY); break; - case PROVIDER_ERROR: + if (providerEventDetails != null + && providerEventDetails.getErrorCode() == ErrorCode.PROVIDER_FATAL) { + onFatal(); + break; + } + if (syncResources.getPreviousEvent() != ProviderEvent.PROVIDER_ERROR) { onError(); syncResources.setPreviousEvent(ProviderEvent.PROVIDER_ERROR); } break; - default: - log.warn("Unknown event {}", flagdProviderEvent.getEvent()); + log.warn("Unknown event {}", providerEvent); } } } - private void onConfigurationChanged(FlagdProviderEvent flagdProviderEvent) { - this.emitProviderConfigurationChanged(ProviderEventDetails.builder() - .flagsChanged(flagdProviderEvent.getFlagsChanged()) - .message("configuration changed") - .build()); - } - private void onReady() { if (syncResources.initialize()) { log.info("Initialized FlagdProvider"); @@ -284,4 +282,17 @@ private void onError() { TimeUnit.SECONDS); } } + + private void onFatal() { + if (errorTask != null && !errorTask.isCancelled()) { + errorTask.cancel(false); + } + this.syncResources.setFatal(true); + + this.emitProviderError(ProviderEventDetails.builder() + .errorCode(ErrorCode.PROVIDER_FATAL) + .build()); + + shutdown(); + } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResources.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResources.java index 03d444528..e173cc5df 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResources.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResources.java @@ -3,6 +3,7 @@ import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.ImmutableContext; import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.exceptions.FatalError; import dev.openfeature.sdk.exceptions.GeneralError; import lombok.Getter; import lombok.Setter; @@ -16,8 +17,11 @@ class FlagdProviderSyncResources { @Setter private volatile ProviderEvent previousEvent = null; + @Setter + private volatile boolean isFatal; + private volatile EvaluationContext enrichedContext = new ImmutableContext(); - private volatile boolean initialized; + private volatile boolean isInitialized; private volatile boolean isShutDown; public void setEnrichedContext(EvaluationContext context) { @@ -31,32 +35,40 @@ public void setEnrichedContext(EvaluationContext context) { * @return true iff this was the first call to {@code initialize()} */ public synchronized boolean initialize() { - if (this.initialized) { + if (this.isInitialized) { return false; } - this.initialized = true; + this.isInitialized = true; + this.isFatal = false; this.notifyAll(); return true; } /** - * Blocks the calling thread until either {@link FlagdProviderSyncResources#initialize()} or - * {@link FlagdProviderSyncResources#shutdown()} is called or the deadline is exceeded, whatever happens first. If - * {@link FlagdProviderSyncResources#initialize()} has been executed before {@code waitForInitialization(long)} is - * called, it will return instantly. If the deadline is exceeded, a GeneralError will be thrown. - * If {@link FlagdProviderSyncResources#shutdown()} is called in the meantime, an {@link IllegalStateException} will + * Blocks the calling thread until either + * {@link FlagdProviderSyncResources#initialize()} or + * {@link FlagdProviderSyncResources#shutdown()} is called or the deadline is + * exceeded, whatever happens first. If + * {@link FlagdProviderSyncResources#initialize()} has been executed before + * {@code waitForInitialization(long)} is + * called, it will return instantly. If the deadline is exceeded, a GeneralError + * will be thrown. + * If {@link FlagdProviderSyncResources#shutdown()} is called in the meantime, + * an {@link IllegalStateException} will * be thrown. Otherwise, the method will return cleanly. * * @param deadline the maximum time in ms to wait - * @throws GeneralError when the deadline is exceeded before - * {@link FlagdProviderSyncResources#initialize()} is called on this object - * @throws IllegalStateException when {@link FlagdProviderSyncResources#shutdown()} is called or has been called on - * this object + * @throws GeneralError when the deadline is exceeded before + * {@link FlagdProviderSyncResources#initialize()} is + * called on this object, or when + * {@link FlagdProviderSyncResources#shutdown()} + * @throws FatalError when the provider has been marked as fatal during + * shutdown */ public void waitForInitialization(long deadline) { long start = System.currentTimeMillis(); long end = start + deadline; - while (!initialized && !isShutDown) { + while (!isInitialized && !isShutDown) { long now = System.currentTimeMillis(); // if wait(0) is called, the thread would wait forever, so we abort when this would happen if (now >= end) { @@ -68,7 +80,7 @@ public void waitForInitialization(long deadline) { if (isShutDown) { break; } - if (initialized) { // might have changed in the meantime + if (isInitialized) { // might have changed in the meantime return; } try { @@ -80,7 +92,11 @@ public void waitForInitialization(long deadline) { } } if (isShutDown) { - throw new IllegalStateException("Already shut down"); + String msg = "Already shut down due to previous error."; + if (isFatal) { + throw new FatalError(msg); + } + throw new GeneralError(msg); } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java index c898aef3a..a8d0b901f 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java @@ -4,7 +4,6 @@ import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.FlagStore; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.Storage; @@ -20,13 +19,15 @@ import dev.openfeature.sdk.ImmutableMetadata; import dev.openfeature.sdk.ProviderEvaluation; import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; import dev.openfeature.sdk.Reason; +import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; import dev.openfeature.sdk.exceptions.GeneralError; import dev.openfeature.sdk.exceptions.ParseError; import dev.openfeature.sdk.exceptions.TypeMismatchError; +import dev.openfeature.sdk.internal.TriConsumer; import java.util.Map; -import java.util.function.Consumer; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -38,7 +39,7 @@ @Slf4j public class InProcessResolver implements Resolver { private final Storage flagStore; - private final Consumer onConnectionEvent; + private final TriConsumer onConnectionEvent; private final Operator operator; private final String scope; private final QueueSource queueSource; @@ -52,7 +53,8 @@ public class InProcessResolver implements Resolver { * @param onConnectionEvent lambda which handles changes in the * connection/stream */ - public InProcessResolver(FlagdOptions options, Consumer onConnectionEvent) { + public InProcessResolver( + FlagdOptions options, TriConsumer onConnectionEvent) { this.queueSource = getQueueSource(options); this.flagStore = new FlagStore(queueSource); this.onConnectionEvent = onConnectionEvent; @@ -73,14 +75,29 @@ public void init() throws Exception { switch (storageStateChange.getStorageState()) { case OK: log.debug("onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); - onConnectionEvent.accept(new FlagdProviderEvent( + + var eventDetails = ProviderEventDetails.builder() + .flagsChanged(storageStateChange.getChangedFlagsKeys()) + .message("configuration changed") + .build(); + + onConnectionEvent.accept( ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, - storageStateChange.getChangedFlagsKeys(), - storageStateChange.getSyncMetadata())); + eventDetails, + storageStateChange.getSyncMetadata()); + log.debug("post onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); break; + case STALE: + onConnectionEvent.accept(ProviderEvent.PROVIDER_ERROR, null, null); + break; case ERROR: - onConnectionEvent.accept(new FlagdProviderEvent(ProviderEvent.PROVIDER_ERROR)); + onConnectionEvent.accept( + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails.builder() + .errorCode(ErrorCode.PROVIDER_FATAL) + .build(), + null); break; default: log.warn(String.format( diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStore.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStore.java index eaa3dfa5f..b6bfff1ff 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStore.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStore.java @@ -11,6 +11,7 @@ import dev.openfeature.sdk.ImmutableStructure; import dev.openfeature.sdk.Structure; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,7 +110,7 @@ private void streamerListener(final QueueSource connector) throws InterruptedExc switch (payload.getType()) { case DATA: try { - List changedFlagsKeys; + List changedFlagsKeys = Collections.emptyList(); ParsingResult parsingResult = FlagParser.parseString(payload.getFlagData(), throwIfInvalid); Map flagMap = parsingResult.getFlags(); Map flagSetMetadataMap = parsingResult.getFlagSetMetadata(); @@ -133,13 +134,19 @@ private void streamerListener(final QueueSource connector) throws InterruptedExc // catch all exceptions and avoid stream listener interruptions log.warn("Invalid flag sync payload from connector", e); if (!stateBlockingQueue.offer(new StorageStateChange(StorageState.STALE))) { - log.warn("Failed to convey STALE status, queue is full"); + log.warn("Failed to convey TRANSIENT_ERROR status, queue is full"); } } break; case ERROR: + if (!stateBlockingQueue.offer(new StorageStateChange(StorageState.STALE))) { + log.warn("Failed to convey TRANSIENT_ERROR status, queue is full"); + } + break; + case SHUTDOWN: + shutdown(); if (!stateBlockingQueue.offer(new StorageStateChange(StorageState.ERROR))) { - log.warn("Failed to convey ERROR status, queue is full"); + log.warn("Failed to convey FATAL_ERROR status, queue is full"); } break; default: diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/StorageState.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/StorageState.java index c47670b7d..55b22dab4 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/StorageState.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/StorageState.java @@ -1,10 +1,10 @@ package dev.openfeature.contrib.providers.flagd.resolver.process.storage; -/** Satus of the storage. */ +/** Status of the storage. */ public enum StorageState { /** Storage is upto date and working as expected. */ OK, - /** Storage has gone stale(most recent sync failed). May get to OK status with next sync. */ + /** Storage has gone stale (most recent sync failed). May get to OK status with next sync. */ STALE, /** Storage is in an unrecoverable error stage. */ ERROR, diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayload.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayload.java index 071e51085..c31e5bd1b 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayload.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayload.java @@ -8,6 +8,9 @@ @AllArgsConstructor @Getter public class QueuePayload { + public static final QueuePayload ERROR = new QueuePayload(QueuePayloadType.ERROR); + public static final QueuePayload SHUTDOWN = new QueuePayload(QueuePayloadType.SHUTDOWN); + private final QueuePayloadType type; private final String flagData; private final Struct syncContext; @@ -15,4 +18,8 @@ public class QueuePayload { public QueuePayload(QueuePayloadType type, String flagData) { this(type, flagData, null); } + + public QueuePayload(QueuePayloadType type) { + this(type, null, null); + } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayloadType.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayloadType.java index 93675fb60..d9d1c5479 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayloadType.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/QueuePayloadType.java @@ -3,5 +3,6 @@ /** Payload type emitted by {@link QueueSource}. */ public enum QueuePayloadType { DATA, - ERROR + ERROR, + SHUTDOWN } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java index 1e2e043d7..5d245a764 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java @@ -19,6 +19,7 @@ import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; +import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -46,6 +47,7 @@ public class SyncStreamQueueSource implements QueueSource { private final boolean reinitializeOnError; private final FlagdOptions options; private final BlockingQueue outgoingQueue = new LinkedBlockingQueue<>(QUEUE_SIZE); + private final List fatalStatusCodes; private volatile GrpcComponents grpcComponents; /** @@ -76,6 +78,7 @@ public SyncStreamQueueSource(final FlagdOptions options) { providerId = options.getProviderId(); maxBackoffMs = options.getRetryBackoffMaxMs(); syncMetadataDisabled = options.isSyncMetadataDisabled(); + fatalStatusCodes = options.getFatalStatusCodes(); reinitializeOnError = options.isReinitializeOnError(); this.options = options; initializeChannelComponents(); @@ -93,6 +96,7 @@ protected SyncStreamQueueSource( providerId = options.getProviderId(); maxBackoffMs = options.getRetryBackoffMaxMs(); syncMetadataDisabled = options.isSyncMetadataDisabled(); + fatalStatusCodes = options.getFatalStatusCodes(); reinitializeOnError = options.isReinitializeOnError(); this.options = options; this.grpcComponents = new GrpcComponents(connectorMock, stubMock, blockingStubMock); @@ -155,12 +159,14 @@ public BlockingQueue getStreamQueue() { * @throws InterruptedException if stream can't be closed within deadline. */ public void shutdown() throws InterruptedException { + // Do not enqueue errors from here, as this method can be called externally, causing multiple shutdown signals // Use atomic compareAndSet to ensure shutdown is only executed once // This prevents race conditions when shutdown is called from multiple threads if (!shutdown.compareAndSet(false, true)) { log.debug("Shutdown already in progress or completed"); return; } + grpcComponents.channelConnector.shutdown(); } @@ -184,23 +190,41 @@ private void observeSyncStream() { } log.debug("Initializing sync stream request"); - SyncStreamObserver observer = new SyncStreamObserver(outgoingQueue, shouldThrottle); + SyncStreamObserver observer = new SyncStreamObserver(outgoingQueue); try { observer.metadata = getMetadata(); - } catch (Exception metaEx) { - // retry if getMetadata fails - String message = metaEx.getMessage(); - log.debug("Metadata request error: {}, will restart", message, metaEx); - enqueueError(String.format("Error in getMetadata request: %s", message)); + } catch (StatusRuntimeException metaEx) { + if (fatalStatusCodes.contains(metaEx.getStatus().getCode().name())) { + log.info( + "Fatal status code for metadata request: {}, not retrying", + metaEx.getStatus().getCode()); + shutdown(); + enqueue(QueuePayload.SHUTDOWN); + } else { + // retry for other status codes + String message = metaEx.getMessage(); + log.debug("Metadata request error: {}, will restart", message, metaEx); + enqueue(QueuePayload.ERROR); + } shouldThrottle.set(true); continue; } try { syncFlags(observer); - } catch (Exception ex) { - log.error("Unexpected sync stream exception, will restart.", ex); - enqueueError(String.format("Error in syncStream: %s", ex.getMessage())); + handleObserverError(observer); + } catch (StatusRuntimeException ex) { + if (fatalStatusCodes.contains(ex.getStatus().getCode().name())) { + log.info( + "Fatal status code during sync stream: {}, not retrying", + ex.getStatus().getCode()); + shutdown(); + enqueue(QueuePayload.SHUTDOWN); + } else { + // retry for other status codes + log.error("Unexpected sync stream exception, will restart.", ex); + enqueue(QueuePayload.ERROR); + } shouldThrottle.set(true); } } catch (InterruptedException ie) { @@ -267,26 +291,40 @@ private void syncFlags(SyncStreamObserver streamObserver) { streamObserver.done.await(); } - private void enqueueError(String message) { - enqueueError(outgoingQueue, message); + private void handleObserverError(SyncStreamObserver observer) throws InterruptedException { + if (observer.throwable == null) { + return; + } + + Throwable throwable = observer.throwable; + Status status = Status.fromThrowable(throwable); + String message = throwable.getMessage(); + if (fatalStatusCodes.contains(status.getCode().name())) { + shutdown(); + } else { + log.debug("Stream error: {}, will restart", message, throwable); + enqueue(QueuePayload.ERROR); + } + + // Set throttling flag to ensure backoff before retry + this.shouldThrottle.set(true); } - private static void enqueueError(BlockingQueue queue, String message) { - if (!queue.offer(new QueuePayload(QueuePayloadType.ERROR, message, null))) { - log.error("Failed to convey ERROR status, queue is full"); + private void enqueue(QueuePayload queuePayload) { + if (!outgoingQueue.offer(queuePayload)) { + log.error("Failed to convey {} status, queue is full", queuePayload.getType()); } } private static class SyncStreamObserver implements StreamObserver { private final BlockingQueue outgoingQueue; - private final AtomicBoolean shouldThrottle; private final Awaitable done = new Awaitable(); private Struct metadata; + private Throwable throwable; - public SyncStreamObserver(BlockingQueue outgoingQueue, AtomicBoolean shouldThrottle) { + public SyncStreamObserver(BlockingQueue outgoingQueue) { this.outgoingQueue = outgoingQueue; - this.shouldThrottle = shouldThrottle; } @Override @@ -303,16 +341,9 @@ public void onNext(SyncFlagsResponse syncFlagsResponse) { @Override public void onError(Throwable throwable) { - try { - String message = throwable != null ? throwable.getMessage() : "unknown"; - log.debug("Stream error: {}, will restart", message, throwable); - enqueueError(outgoingQueue, String.format("Error from stream: %s", message)); - - // Set throttling flag to ensure backoff before retry - this.shouldThrottle.set(true); - } finally { - done.wakeup(); - } + log.debug("Sync stream error received", throwable); + this.throwable = throwable; + done.wakeup(); } @Override diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolver.java index afb06120b..b09634088 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolver.java @@ -11,7 +11,6 @@ import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelBuilder; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelConnector; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.common.QueueingStreamObserver; import dev.openfeature.contrib.providers.flagd.resolver.common.StreamResponseModel; import dev.openfeature.contrib.providers.flagd.resolver.rpc.cache.Cache; @@ -27,16 +26,20 @@ import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; +import dev.openfeature.sdk.ErrorCode; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.ImmutableMetadata; import dev.openfeature.sdk.ProviderEvaluation; import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; +import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; import dev.openfeature.sdk.exceptions.FlagNotFoundError; import dev.openfeature.sdk.exceptions.GeneralError; import dev.openfeature.sdk.exceptions.OpenFeatureError; import dev.openfeature.sdk.exceptions.ParseError; import dev.openfeature.sdk.exceptions.TypeMismatchError; +import dev.openfeature.sdk.internal.TriConsumer; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; @@ -46,7 +49,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; import java.util.function.Function; import lombok.extern.slf4j.Slf4j; @@ -60,14 +62,16 @@ public final class RpcResolver implements Resolver { private static final int QUEUE_SIZE = 5; private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AtomicBoolean successfulConnection = new AtomicBoolean(false); private final ChannelConnector connector; private final Cache cache; private final ResolveStrategy strategy; private final FlagdOptions options; private final LinkedBlockingQueue> incomingQueue; - private final Consumer onProviderEvent; + private final TriConsumer onProviderEvent; private final ServiceStub stub; private final ServiceBlockingStub blockingStub; + private final List fatalStatusCodes; /** * Resolves flag values using @@ -79,7 +83,9 @@ public final class RpcResolver implements Resolver { * @param onProviderEvent lambda which handles changes in the connection/stream */ public RpcResolver( - final FlagdOptions options, final Cache cache, final Consumer onProviderEvent) { + final FlagdOptions options, + final Cache cache, + final TriConsumer onProviderEvent) { this.cache = cache; this.strategy = ResolveFactory.getStrategy(options); this.options = options; @@ -89,13 +95,14 @@ public RpcResolver( this.stub = ServiceGrpc.newStub(this.connector.getChannel()).withWaitForReady(); this.blockingStub = ServiceGrpc.newBlockingStub(this.connector.getChannel()).withWaitForReady(); + this.fatalStatusCodes = options.getFatalStatusCodes(); } // testing only protected RpcResolver( final FlagdOptions options, final Cache cache, - final Consumer onProviderEvent, + final TriConsumer onProviderEvent, ServiceStub mockStub, ServiceBlockingStub mockBlockingStub, ChannelConnector connector) { @@ -107,6 +114,7 @@ protected RpcResolver( this.onProviderEvent = onProviderEvent; this.stub = mockStub; this.blockingStub = mockBlockingStub; + this.fatalStatusCodes = options.getFatalStatusCodes(); } /** @@ -341,20 +349,35 @@ private void observeEventStream() throws InterruptedException { final StreamResponseModel taken = incomingQueue.take(); if (taken.isComplete()) { log.debug("Event stream completed, will reconnect"); - this.handleErrorOrComplete(); + this.handleErrorOrComplete(false); // The stream is complete, we still try to reconnect break; } Throwable streamException = taken.getError(); if (streamException != null) { - log.debug( - "Exception in event stream connection, streamException {}, will reconnect", - streamException); - this.handleErrorOrComplete(); + if (streamException instanceof StatusRuntimeException + && fatalStatusCodes.contains(((StatusRuntimeException) streamException) + .getStatus() + .getCode() + .name()) + && !successfulConnection.get()) { + log.debug( + "Fatal error code received: {}", + ((StatusRuntimeException) streamException) + .getStatus() + .getCode()); + this.handleErrorOrComplete(true); + } else { + log.debug( + "Exception in event stream connection, streamException {}, will reconnect", + streamException); + this.handleErrorOrComplete(false); + } break; } + successfulConnection.set(true); final EventStreamResponse response = taken.getResponse(); log.debug("Got stream response: {}", response); @@ -395,7 +418,10 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { changedFlags.forEach(this.cache::remove); } - onProviderEvent.accept(new FlagdProviderEvent(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, changedFlags)); + onProviderEvent.accept( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + ProviderEventDetails.builder().flagsChanged(changedFlags).build(), + null); } /** @@ -403,16 +429,18 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { */ private void handleProviderReadyEvent() { log.debug("Emitting provider ready event"); - onProviderEvent.accept(new FlagdProviderEvent(ProviderEvent.PROVIDER_READY)); + onProviderEvent.accept(ProviderEvent.PROVIDER_READY, null, null); } /** * Handles provider error events by clearing the cache (if enabled) and notifying listeners of the error. */ - private void handleErrorOrComplete() { + private void handleErrorOrComplete(boolean fatal) { log.debug("Emitting provider error event"); + ErrorCode errorCode = fatal ? ErrorCode.PROVIDER_FATAL : ErrorCode.GENERAL; + var details = ProviderEventDetails.builder().errorCode(errorCode).build(); // complete is an error, logically...even if the server went down gracefully we need to reconnect. - onProviderEvent.accept(new FlagdProviderEvent(ProviderEvent.PROVIDER_ERROR)); + onProviderEvent.accept(ProviderEvent.PROVIDER_ERROR, details, null); } } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResourcesTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResourcesTest.java index fd7f55111..684caf0d9 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResourcesTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderSyncResourcesTest.java @@ -1,5 +1,6 @@ package dev.openfeature.contrib.providers.flagd; +import dev.openfeature.sdk.exceptions.FatalError; import dev.openfeature.sdk.exceptions.GeneralError; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -111,18 +112,52 @@ void callingInitialize_wakesUpWaitingThread() throws InterruptedException { @Timeout(2) @Test - void callingShutdown_wakesUpWaitingThreadWithException() throws InterruptedException { + void callingShutdownWithPreviousNonFatal_wakesUpWaitingThread_WithGeneralException() throws InterruptedException { final AtomicBoolean isWaiting = new AtomicBoolean(); final AtomicBoolean successfulTest = new AtomicBoolean(); + flagdProviderSyncResources.setFatal(false); + Thread waitingThread = new Thread(() -> { long start = System.currentTimeMillis(); isWaiting.set(true); - Assertions.assertThrows( - IllegalStateException.class, () -> flagdProviderSyncResources.waitForInitialization(10000)); + Assertions.assertThrows(GeneralError.class, () -> flagdProviderSyncResources.waitForInitialization(10000)); + + long end = System.currentTimeMillis(); + long duration = end - start; + var wait = MAX_TIME_TOLERANCE * 3; + successfulTest.set(duration < wait); + }); + waitingThread.start(); + + while (!isWaiting.get()) { + Thread.yield(); + } + + Thread.sleep(MAX_TIME_TOLERANCE); // waitingThread should have started waiting in the meantime + + flagdProviderSyncResources.shutdown(); + + waitingThread.join(); + + Assertions.assertTrue(successfulTest.get()); + } + + @Timeout(2) + @Test + void callingShutdownWithPreviousFatal_wakesUpWaitingThread_WithFatalException() throws InterruptedException { + final AtomicBoolean isWaiting = new AtomicBoolean(); + final AtomicBoolean successfulTest = new AtomicBoolean(); + flagdProviderSyncResources.setFatal(true); + + Thread waitingThread = new Thread(() -> { + long start = System.currentTimeMillis(); + isWaiting.set(true); + Assertions.assertThrows(FatalError.class, () -> flagdProviderSyncResources.waitForInitialization(10000)); long end = System.currentTimeMillis(); long duration = end - start; - successfulTest.set(duration < MAX_TIME_TOLERANCE * 2); + var wait = MAX_TIME_TOLERANCE * 3; + successfulTest.set(duration < wait); }); waitingThread.start(); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java index 115887002..2af572fe5 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java @@ -19,7 +19,6 @@ import com.google.protobuf.Struct; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelConnector; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.process.InProcessResolver; import dev.openfeature.contrib.providers.flagd.resolver.process.MockStorage; import dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag; @@ -45,9 +44,11 @@ import dev.openfeature.sdk.MutableStructure; import dev.openfeature.sdk.OpenFeatureAPI; import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; import dev.openfeature.sdk.Reason; import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; +import dev.openfeature.sdk.internal.TriConsumer; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.ArrayList; @@ -59,7 +60,6 @@ import java.util.Optional; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Function; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -556,11 +556,12 @@ void initializationAndShutdown() throws Exception { flagResolver.setAccessible(true); flagResolver.set(provider, resolverMock); - Method onProviderEvent = FlagdProvider.class.getDeclaredMethod("onProviderEvent", FlagdProviderEvent.class); + Method onProviderEvent = FlagdProvider.class.getDeclaredMethod( + "onProviderEvent", ProviderEvent.class, ProviderEventDetails.class, Structure.class); onProviderEvent.setAccessible(true); doAnswer((i) -> { - onProviderEvent.invoke(provider, new FlagdProviderEvent(ProviderEvent.PROVIDER_READY)); + onProviderEvent.invoke(provider, ProviderEvent.PROVIDER_READY, null, null); return null; }) .when(resolverMock) @@ -596,17 +597,16 @@ void contextEnrichment() throws Exception { // mock a resolver try (MockedConstruction mockResolver = mockConstruction(InProcessResolver.class, (mock, context) -> { - Consumer onConnectionEvent; + TriConsumer onConnectionEvent; // get a reference to the onConnectionEvent callback - onConnectionEvent = - (Consumer) context.arguments().get(1); + onConnectionEvent = (TriConsumer) + context.arguments().get(1); // when our mock resolver initializes, it runs the passed onConnectionEvent // callback doAnswer(invocation -> { - onConnectionEvent.accept( - new FlagdProviderEvent(ProviderEvent.PROVIDER_READY, metadata)); + onConnectionEvent.accept(ProviderEvent.PROVIDER_READY, null, metadata); return null; }) .when(mock) @@ -637,17 +637,16 @@ void updatesSyncMetadataWithCallback() throws Exception { // mock a resolver try (MockedConstruction mockResolver = mockConstruction(InProcessResolver.class, (mock, context) -> { - Consumer onConnectionEvent; + TriConsumer onConnectionEvent; // get a reference to the onConnectionEvent callback - onConnectionEvent = - (Consumer) context.arguments().get(1); + onConnectionEvent = (TriConsumer) + context.arguments().get(1); // when our mock resolver initializes, it runs the passed onConnectionEvent // callback doAnswer(invocation -> { - onConnectionEvent.accept( - new FlagdProviderEvent(ProviderEvent.PROVIDER_READY, metadata)); + onConnectionEvent.accept(ProviderEvent.PROVIDER_READY, null, metadata); return null; }) .when(mock) @@ -690,7 +689,7 @@ private FlagdProvider createProvider(ChannelConnector connector, ServiceBlocking private FlagdProvider createProvider( ChannelConnector connector, Cache cache, ServiceStub mockStub, ServiceBlockingStub mockBlockingStub) { final FlagdOptions flagdOptions = FlagdOptions.builder().build(); - final RpcResolver grpcResolver = new RpcResolver(flagdOptions, cache, (connectionEvent) -> {}); + final RpcResolver grpcResolver = new RpcResolver(flagdOptions, cache, (event, details, metadata) -> {}); try { Field resolver = RpcResolver.class.getDeclaredField("connector"); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/ProviderSteps.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/ProviderSteps.java index 27806f955..90d082292 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/ProviderSteps.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/ProviderSteps.java @@ -1,6 +1,7 @@ package dev.openfeature.contrib.providers.flagd.e2e.steps; import static io.restassured.RestAssured.when; +import static org.assertj.core.api.Assertions.assertThat; import dev.openfeature.contrib.providers.flagd.Config; import dev.openfeature.contrib.providers.flagd.FlagdOptions; @@ -9,10 +10,12 @@ import dev.openfeature.contrib.providers.flagd.e2e.State; import dev.openfeature.sdk.FeatureProvider; import dev.openfeature.sdk.OpenFeatureAPI; +import dev.openfeature.sdk.ProviderState; import io.cucumber.java.After; import io.cucumber.java.AfterAll; import io.cucumber.java.BeforeAll; import io.cucumber.java.en.Given; +import io.cucumber.java.en.Then; import io.cucumber.java.en.When; import java.io.File; import java.io.IOException; @@ -31,6 +34,7 @@ public class ProviderSteps extends AbstractSteps { public static final int UNAVAILABLE_PORT = 9999; + public static final int FORBIDDEN_PORT = 9212; static ComposeContainer container; static Path sharedTempDir; @@ -49,6 +53,7 @@ public static void beforeAll() throws IOException { .withExposedService("flagd", 8015, Wait.forListeningPort()) .withExposedService("flagd", 8080, Wait.forListeningPort()) .withExposedService("envoy", 9211, Wait.forListeningPort()) + .withExposedService("envoy", FORBIDDEN_PORT, Wait.forListeningPort()) .withStartupTimeout(Duration.ofSeconds(45)); container.start(); } @@ -85,6 +90,10 @@ public void setupProvider(String providerType) throws InterruptedException { } wait = false; break; + case "forbidden": + state.builder.port(container.getServicePort("envoy", FORBIDDEN_PORT)); + wait = false; + break; case "socket": this.state.providerType = ProviderType.SOCKET; String socketPath = @@ -188,4 +197,9 @@ public void the_flag_was_modded() { .then() .statusCode(200); } + + @Then("the client should be in {} state") + public void the_client_should_be_in_fatal_state(String clientState) { + assertThat(state.client.getProviderState()).isEqualTo(ProviderState.valueOf(clientState.toUpperCase())); + } } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/Utils.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/Utils.java index 7dca50533..626105ce4 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/Utils.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/Utils.java @@ -4,7 +4,10 @@ import dev.openfeature.contrib.providers.flagd.resolver.rpc.cache.CacheType; import dev.openfeature.sdk.Value; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; public final class Utils { @@ -37,6 +40,10 @@ public static Object convert(String value, String type) throws ClassNotFoundExce } case "CacheType": return CacheType.valueOf(value.toUpperCase()).getValue(); + case "StringList": + return value.isEmpty() + ? List.of() + : Arrays.stream(value.split(",")).map(String::trim).collect(Collectors.toList()); case "Object": return Value.objectToValue(new ObjectMapper().readValue(value, Object.class)); } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java index 34c660702..04670b397 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java @@ -23,7 +23,6 @@ import dev.openfeature.contrib.providers.flagd.Config; import dev.openfeature.contrib.providers.flagd.FlagdOptions; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.StorageState; @@ -36,11 +35,15 @@ import dev.openfeature.sdk.MutableContext; import dev.openfeature.sdk.MutableStructure; import dev.openfeature.sdk.ProviderEvaluation; +import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; import dev.openfeature.sdk.Reason; +import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; import dev.openfeature.sdk.exceptions.GeneralError; import dev.openfeature.sdk.exceptions.ParseError; import dev.openfeature.sdk.exceptions.TypeMismatchError; +import dev.openfeature.sdk.internal.TriConsumer; import java.lang.reflect.Field; import java.time.Duration; import java.util.Collections; @@ -49,7 +52,6 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -59,7 +61,7 @@ void onError_delegatesToQueueSource() throws Exception { // given FlagdOptions options = FlagdOptions.builder().build(); // option value doesn't matter here SyncStreamQueueSource mockConnector = mock(SyncStreamQueueSource.class); - InProcessResolver resolver = new InProcessResolver(options, e -> {}); + InProcessResolver resolver = new InProcessResolver(options, (event, details, metadata) -> {}); // Inject mock connector java.lang.reflect.Field queueSourceField = InProcessResolver.class.getDeclaredField("queueSource"); @@ -109,12 +111,15 @@ void eventHandling() throws Throwable { final MutableStructure syncMetadata = new MutableStructure(); syncMetadata.add(key, val); - InProcessResolver inProcessResolver = getInProcessResolverWith( - new MockStorage(new HashMap<>(), sender), - connectionEvent -> receiver.offer(new StorageStateChange( - connectionEvent.isDisconnected() ? StorageState.ERROR : StorageState.OK, - connectionEvent.getFlagsChanged(), - connectionEvent.getSyncMetadata()))); + InProcessResolver inProcessResolver = + getInProcessResolverWith(new MockStorage(new HashMap<>(), sender), (event, details, metadata) -> { + boolean isDisconnected = + event == ProviderEvent.PROVIDER_ERROR || event == ProviderEvent.PROVIDER_STALE; + receiver.offer(new StorageStateChange( + isDisconnected ? StorageState.ERROR : StorageState.OK, + details != null ? details.getFlagsChanged() : Collections.emptyList(), + metadata)); + }); // when - init and emit events Thread initThread = new Thread(() -> { @@ -149,7 +154,7 @@ public void simpleBooleanResolving() throws Exception { flagMap.put("booleanFlag", BOOLEAN_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -168,7 +173,7 @@ public void simpleDoubleResolving() throws Exception { flagMap.put("doubleFlag", DOUBLE_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -187,7 +192,7 @@ public void fetchIntegerAsDouble() throws Exception { flagMap.put("doubleFlag", DOUBLE_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -206,7 +211,7 @@ public void fetchDoubleAsInt() throws Exception { flagMap.put("integerFlag", INT_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -225,7 +230,7 @@ public void simpleIntResolving() throws Exception { flagMap.put("integerFlag", INT_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -244,7 +249,7 @@ public void simpleObjectResolving() throws Exception { flagMap.put("objectFlag", OBJECT_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); Map typeDefault = new HashMap<>(); typeDefault.put("key", "0164"); @@ -270,7 +275,7 @@ public void missingFlag() throws Exception { final Map flagMap = new HashMap<>(); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when/then ProviderEvaluation missingFlag = @@ -285,7 +290,7 @@ public void disabledFlag() throws Exception { flagMap.put("disabledFlag", DISABLED_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when/then ProviderEvaluation disabledFlag = @@ -300,7 +305,7 @@ public void variantMismatchFlag() throws Exception { flagMap.put("mismatchFlag", VARIANT_MISMATCH_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when/then assertThrows(GeneralError.class, () -> { @@ -315,7 +320,7 @@ public void typeMismatchEvaluation() throws Exception { flagMap.put("booleanFlag", BOOLEAN_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when/then assertThrows(TypeMismatchError.class, () -> { @@ -330,7 +335,7 @@ public void booleanShorthandEvaluation() throws Exception { flagMap.put("shorthand", FLAG_WIH_SHORTHAND_TARGETING); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); ProviderEvaluation providerEvaluation = inProcessResolver.booleanEvaluation("shorthand", false, new ImmutableContext()); @@ -348,7 +353,7 @@ public void targetingMatchedEvaluationFlag() throws Exception { flagMap.put("stringFlag", FLAG_WIH_IF_IN_TARGET); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = inProcessResolver.stringEvaluation( @@ -367,7 +372,7 @@ public void targetingUnmatchedEvaluationFlag() throws Exception { flagMap.put("stringFlag", FLAG_WIH_IF_IN_TARGET); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = inProcessResolver.stringEvaluation( @@ -386,7 +391,7 @@ public void explicitTargetingKeyHandling() throws NoSuchFieldException, IllegalA flagMap.put("stringFlag", FLAG_WITH_TARGETING_KEY); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when ProviderEvaluation providerEvaluation = @@ -405,7 +410,7 @@ public void targetingErrorEvaluationFlag() throws Exception { flagMap.put("targetingErrorFlag", FLAG_WIH_INVALID_TARGET); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), (connectionEvent) -> {}); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}); // when/then assertThrows(ParseError.class, () -> { @@ -440,7 +445,7 @@ void selectorIsAddedToFlagMetadata() throws Exception { flagMap.put("flag", INT_FLAG); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), connectionEvent -> {}, "selector"); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}, "selector"); // when ProviderEvaluation providerEvaluation = @@ -460,7 +465,7 @@ void selectorIsOverwrittenByFlagMetadata() throws Exception { flagMap.put("flag", new FeatureFlag("stage", "loop", stringVariants, "", flagMetadata)); InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap), connectionEvent -> {}, "selector"); + getInProcessResolverWith(new MockStorage(flagMap), (event, details, metadata) -> {}, "selector"); // when ProviderEvaluation providerEvaluation = @@ -481,8 +486,8 @@ void flagSetMetadataIsAddedToEvaluation() throws Exception { final Map flagSetMetadata = new HashMap<>(); flagSetMetadata.put("flagSetMetadata", "metadata"); - InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap, flagSetMetadata), connectionEvent -> {}, "selector"); + InProcessResolver inProcessResolver = getInProcessResolverWith( + new MockStorage(flagMap, flagSetMetadata), (event, details, metadata) -> {}, "selector"); // when ProviderEvaluation providerEvaluation = @@ -502,8 +507,8 @@ void flagSetMetadataIsAddedToFailingEvaluation() throws Exception { final Map flagSetMetadata = new HashMap<>(); flagSetMetadata.put("flagSetMetadata", "metadata"); - InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap, flagSetMetadata), connectionEvent -> {}, "selector"); + InProcessResolver inProcessResolver = getInProcessResolverWith( + new MockStorage(flagMap, flagSetMetadata), (event, details, metadata) -> {}, "selector"); // when ProviderEvaluation providerEvaluation = @@ -526,8 +531,8 @@ void flagSetMetadataIsOverwrittenByFlagMetadataToEvaluation() throws Exception { final Map flagSetMetadata = new HashMap<>(); flagSetMetadata.put("key", "unexpected"); - InProcessResolver inProcessResolver = - getInProcessResolverWith(new MockStorage(flagMap, flagSetMetadata), connectionEvent -> {}, "selector"); + InProcessResolver inProcessResolver = getInProcessResolverWith( + new MockStorage(flagMap, flagSetMetadata), (event, details, metadata) -> {}, "selector"); // when ProviderEvaluation providerEvaluation = @@ -541,12 +546,13 @@ void flagSetMetadataIsOverwrittenByFlagMetadataToEvaluation() throws Exception { private InProcessResolver getInProcessResolverWith(final FlagdOptions options, final MockStorage storage) throws NoSuchFieldException, IllegalAccessException { - final InProcessResolver resolver = new InProcessResolver(options, connectionEvent -> {}); + final InProcessResolver resolver = new InProcessResolver(options, (event, details, metadata) -> {}); return injectFlagStore(resolver, storage); } private InProcessResolver getInProcessResolverWith( - final MockStorage storage, final Consumer onConnectionEvent) + final MockStorage storage, + final TriConsumer onConnectionEvent) throws NoSuchFieldException, IllegalAccessException { final InProcessResolver resolver = @@ -555,7 +561,9 @@ private InProcessResolver getInProcessResolverWith( } private InProcessResolver getInProcessResolverWith( - final MockStorage storage, final Consumer onConnectionEvent, String selector) + final MockStorage storage, + final TriConsumer onConnectionEvent, + String selector) throws NoSuchFieldException, IllegalAccessException { final InProcessResolver resolver = new InProcessResolver( diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStoreTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStoreTest.java index 86ca298e3..e58b4eb3f 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStoreTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/FlagStoreTest.java @@ -64,14 +64,14 @@ void connectorHandling() throws Exception { }); assertTimeoutPreemptively(Duration.ofMillis(maxDelay), () -> { - assertEquals(StorageState.ERROR, states.take().getStorageState()); + assertEquals(StorageState.STALE, states.take().getStorageState()); }); // Shutdown handling store.shutdown(); assertTimeoutPreemptively(Duration.ofMillis(maxDelay), () -> { - assertEquals(StorageState.ERROR, states.take().getStorageState()); + assertEquals(StorageState.STALE, states.take().getStorageState()); }); } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java index ce50a4d20..53d56b14f 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java @@ -27,79 +27,12 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.stubbing.Answer; class SyncStreamQueueSourceTest { - @Test - void reinitializeChannelComponents_reinitializesWhenEnabled() throws InterruptedException { - FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); - ChannelConnector initialConnector = mock(ChannelConnector.class); - FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); - FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); - - try { - // save reference to old GrpcComponents - Object oldComponents = getPrivateField(queueSource, "grpcComponents"); - queueSource.reinitializeChannelComponents(); - Object newComponents = getPrivateField(queueSource, "grpcComponents"); - // should have replaced grpcComponents - assertNotNull(newComponents); - org.junit.jupiter.api.Assertions.assertNotSame(oldComponents, newComponents); - } finally { - queueSource.shutdown(); - } - } - - @Test - void reinitializeChannelComponents_doesNothingWhenDisabled() throws InterruptedException { - FlagdOptions options = FlagdOptions.builder().reinitializeOnError(false).build(); - ChannelConnector initialConnector = mock(ChannelConnector.class); - FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); - FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); - - try { - Object oldComponents = getPrivateField(queueSource, "grpcComponents"); - queueSource.reinitializeChannelComponents(); - Object newComponents = getPrivateField(queueSource, "grpcComponents"); - // should NOT have replaced grpcComponents - org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); - } finally { - queueSource.shutdown(); - } - } - - @Test - void reinitializeChannelComponents_doesNothingWhenShutdown() throws InterruptedException { - FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); - ChannelConnector initialConnector = mock(ChannelConnector.class); - FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); - FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); - - queueSource.shutdown(); - Object oldComponents = getPrivateField(queueSource, "grpcComponents"); - queueSource.reinitializeChannelComponents(); - Object newComponents = getPrivateField(queueSource, "grpcComponents"); - // should NOT have replaced grpcComponents - org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); - } - // helper to access private fields via reflection - private static Object getPrivateField(Object instance, String fieldName) { - try { - java.lang.reflect.Field field = instance.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - return field.get(instance); - } catch (Exception e) { - throw new RuntimeException(e); - } - } private ChannelConnector mockConnector; private FlagSyncServiceBlockingStub blockingStub; @@ -108,6 +41,7 @@ private static Object getPrivateField(Object instance, String fieldName) { private FlagSyncServiceStub asyncErrorStub; private StreamObserver observer; private CountDownLatch latch; // used to wait for observer to be initialized + private SyncStreamQueueSource queueSource; @BeforeEach @SuppressWarnings("deprecation") @@ -170,12 +104,83 @@ public void setup() throws Exception { .syncFlags(any(SyncFlagsRequest.class), any()); // mock the initialize } + @AfterEach + public void tearDown() throws Exception { + queueSource.shutdown(); + } + + @Test + void reinitializeChannelComponents_reinitializesWhenEnabled() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + queueSource = new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + try { + // save reference to old GrpcComponents + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should have replaced grpcComponents + assertNotNull(newComponents); + org.junit.jupiter.api.Assertions.assertNotSame(oldComponents, newComponents); + } finally { + queueSource.shutdown(); + } + } + + @Test + void reinitializeChannelComponents_doesNothingWhenDisabled() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(false).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + queueSource = new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + try { + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should NOT have replaced grpcComponents + org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); + } finally { + queueSource.shutdown(); + } + } + + @Test + void reinitializeChannelComponents_doesNothingWhenShutdown() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + queueSource = new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + queueSource.shutdown(); + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should NOT have replaced grpcComponents + org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); + } + // helper to access private fields via reflection + private static Object getPrivateField(Object instance, String fieldName) { + try { + java.lang.reflect.Field field = instance.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(instance); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @Test void syncInitError_DoesNotBusyWait() throws Exception { // make sure we do not spin in a busy loop on immediately errors int maxBackoffMs = 1000; - SyncStreamQueueSource queueSource = new SyncStreamQueueSource( + queueSource = new SyncStreamQueueSource( FlagdOptions.builder().retryBackoffMaxMs(maxBackoffMs).build(), mockConnector, syncErrorStub, @@ -200,7 +205,7 @@ void asyncInitError_DoesNotBusyWait() throws Exception { // make sure we do not spin in a busy loop on async errors int maxBackoffMs = 1000; - SyncStreamQueueSource queueSource = new SyncStreamQueueSource( + queueSource = new SyncStreamQueueSource( FlagdOptions.builder().retryBackoffMaxMs(maxBackoffMs).build(), mockConnector, asyncErrorStub, @@ -222,8 +227,7 @@ void asyncInitError_DoesNotBusyWait() throws Exception { @Test void onNextEnqueuesDataPayload() throws Exception { - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); + queueSource = new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); latch = new CountDownLatch(1); queueSource.init(); latch.await(); @@ -245,7 +249,7 @@ void onNextEnqueuesDataPayload() throws Exception { @SuppressWarnings("deprecation") void onNextEnqueuesDataPayloadMetadataDisabled() throws Exception { // disable GetMetadata call - SyncStreamQueueSource queueSource = new SyncStreamQueueSource( + queueSource = new SyncStreamQueueSource( FlagdOptions.builder().syncMetadataDisabled(true).build(), mockConnector, stub, blockingStub); latch = new CountDownLatch(1); queueSource.init(); @@ -269,8 +273,7 @@ void onNextEnqueuesDataPayloadMetadataDisabled() throws Exception { @Test void onNextEnqueuesDataPayloadWithSyncContext() throws Exception { // disable GetMetadata call - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); + queueSource = new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); latch = new CountDownLatch(1); queueSource.init(); latch.await(); @@ -292,8 +295,7 @@ void onNextEnqueuesDataPayloadWithSyncContext() throws Exception { @Test void onErrorEnqueuesDataPayload() throws Exception { - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); + queueSource = new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); latch = new CountDownLatch(1); queueSource.init(); latch.await(); @@ -314,8 +316,7 @@ void onErrorEnqueuesDataPayload() throws Exception { @Test void onCompletedEnqueuesDataPayload() throws Exception { - SyncStreamQueueSource queueSource = - new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); + queueSource = new SyncStreamQueueSource(FlagdOptions.builder().build(), mockConnector, stub, blockingStub); latch = new CountDownLatch(1); queueSource.init(); latch.await(); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolverTest.java index 119f9e2e6..955d0fa2b 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/rpc/RpcResolverTest.java @@ -3,7 +3,7 @@ import static org.awaitility.Awaitility.await; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -13,14 +13,15 @@ import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelConnector; -import dev.openfeature.contrib.providers.flagd.resolver.common.FlagdProviderEvent; import dev.openfeature.contrib.providers.flagd.resolver.common.QueueingStreamObserver; import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; import dev.openfeature.sdk.ProviderEvent; +import dev.openfeature.sdk.ProviderEventDetails; +import dev.openfeature.sdk.Structure; +import dev.openfeature.sdk.internal.TriConsumer; import java.util.concurrent.CountDownLatch; -import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.invocation.InvocationOnMock; @@ -31,15 +32,16 @@ class RpcResolverTest { private ServiceBlockingStub blockingStub; private ServiceStub stub; private QueueingStreamObserver observer; - private Consumer consumer; + private TriConsumer consumer; private CountDownLatch latch; // used to wait for observer to be initialized + @SuppressWarnings("unchecked") @BeforeEach public void init() throws Exception { latch = new CountDownLatch(1); observer = null; - consumer = mock(Consumer.class); - doNothing().when(consumer).accept(any()); + consumer = mock(TriConsumer.class); + doNothing().when(consumer).accept(any(), any(), any()); blockingStub = mock(ServiceBlockingStub.class); @@ -74,8 +76,7 @@ void onNextWithReadyRunsConsumerWithReady() throws Exception { .build()); // should run consumer with payload - await().untilAsserted(() -> - verify(consumer).accept(argThat((arg) -> arg.getEvent() == ProviderEvent.PROVIDER_READY))); + await().untilAsserted(() -> verify(consumer).accept(eq(ProviderEvent.PROVIDER_READY), any(), any())); // should NOT have restarted the stream (1 call) verify(stub, times(1)).eventStream(any(), any()); } @@ -95,8 +96,8 @@ void onNextWithChangedRunsConsumerWithChanged() throws Exception { // should run consumer with payload verify(stub, times(1)).eventStream(any(), any()); // should have restarted the stream (2 calls) - await().untilAsserted(() -> verify(consumer) - .accept(argThat((arg) -> arg.getEvent() == ProviderEvent.PROVIDER_CONFIGURATION_CHANGED))); + await().untilAsserted( + () -> verify(consumer).accept(eq(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED), any(), any())); } @Test @@ -110,8 +111,7 @@ void onCompletedRerunsStreamWithError() throws Exception { observer.onCompleted(); // should run consumer with error - await().untilAsserted(() -> - verify(consumer).accept(argThat((arg) -> arg.getEvent() == ProviderEvent.PROVIDER_ERROR))); + await().untilAsserted(() -> verify(consumer).accept(eq(ProviderEvent.PROVIDER_ERROR), any(), any())); // should have restarted the stream (2 calls) await().untilAsserted(() -> verify(stub, times(2)).eventStream(any(), any())); } @@ -127,8 +127,7 @@ void onErrorRunsConsumerWithError() throws Exception { observer.onError(new Exception("fake error")); // should run consumer with error - await().untilAsserted(() -> - verify(consumer).accept(argThat((arg) -> arg.getEvent() == ProviderEvent.PROVIDER_ERROR))); + await().untilAsserted(() -> verify(consumer).accept(eq(ProviderEvent.PROVIDER_ERROR), any(), any())); // should have restarted the stream (2 calls) await().untilAsserted(() -> verify(stub, times(2)).eventStream(any(), any())); } diff --git a/providers/flagd/test-harness b/providers/flagd/test-harness index b62f5dbe8..b0057abde 160000 --- a/providers/flagd/test-harness +++ b/providers/flagd/test-harness @@ -1 +1 @@ -Subproject commit b62f5dbe860ecf4f36ec757dfdc0b38f7b3dec6e +Subproject commit b0057abde5d84272d6dd91f4737655c9d6cead15