diff --git a/cloudfoundry-client-reactor/src/main/java/org/cloudfoundry/reactor/tokenprovider/AbstractUaaTokenProvider.java b/cloudfoundry-client-reactor/src/main/java/org/cloudfoundry/reactor/tokenprovider/AbstractUaaTokenProvider.java index 09c23487c0..6cfc85a486 100644 --- a/cloudfoundry-client-reactor/src/main/java/org/cloudfoundry/reactor/tokenprovider/AbstractUaaTokenProvider.java +++ b/cloudfoundry-client-reactor/src/main/java/org/cloudfoundry/reactor/tokenprovider/AbstractUaaTokenProvider.java @@ -80,7 +80,7 @@ public abstract class AbstractUaaTokenProvider implements TokenProvider { private final ConcurrentMap refreshTokenStreams = new ConcurrentHashMap<>(1); - private final ConcurrentMap> refreshTokens = + private final ConcurrentMap refreshTokens = new ConcurrentHashMap<>(1); /** @@ -116,7 +116,10 @@ public final Mono getToken(ConnectionContext connectionContext) { @Override public void invalidate(ConnectionContext connectionContext) { - this.accessTokens.put(connectionContext, token(connectionContext)); + String refreshToken = this.refreshTokens.remove(connectionContext); + if (refreshToken != null) { + this.accessTokens.put(connectionContext, token(connectionContext, refreshToken)); + } } /** @@ -133,6 +136,30 @@ public void invalidate(ConnectionContext connectionContext) { */ abstract void tokenRequestTransformer(HttpClientRequest request, HttpClientForm form); + private Mono token(ConnectionContext connectionContext) { + Mono token = + primaryToken(connectionContext) + .doOnSubscribe(s -> LOGGER.debug("Negotiating using token provider")); + + return cacheResult(connectionContext, token); + } + + private Mono token(ConnectionContext connectionContext, String refreshToken) { + Mono token = + refreshToken(connectionContext, refreshToken) + .doOnSubscribe(s -> LOGGER.debug("Negotiating using refresh token")) + // fall back to primary token in case the refresh_token grant fails + // (expired, revoked, ...) + .switchIfEmpty( + primaryToken(connectionContext) + .doOnSubscribe( + s -> + LOGGER.debug( + "Falling back to token provider"))); + + return cacheResult(connectionContext, token); + } + private static String extractAccessToken(Map payload) { String accessToken = payload.get(ACCESS_TOKEN); @@ -227,8 +254,7 @@ private Consumer> extractRefreshToken(ConnectionContext conn }); } - this.refreshTokens.put( - connectionContext, Mono.just(refreshToken)); + this.refreshTokens.put(connectionContext, refreshToken); getRefreshTokenStream(connectionContext) .sink .emitNext(refreshToken, FAIL_FAST); @@ -297,30 +323,16 @@ private void setAuthorization(HttpHeaders headers) { headers.set(AUTHORIZATION, String.format("Basic %s", encoded)); } - private Mono token(ConnectionContext connectionContext) { - Mono cached = - this.refreshTokens - .getOrDefault(connectionContext, Mono.empty()) - .flatMap( - refreshToken -> - refreshToken(connectionContext, refreshToken) - .doOnSubscribe( - s -> - LOGGER.debug( - "Negotiating using refresh" - + " token"))) - .switchIfEmpty( - primaryToken(connectionContext) - .doOnSubscribe( - s -> - LOGGER.debug( - "Negotiating using token" - + " provider"))); - + /** + * Cache the given mono. If {@link ConnectionContext#getCacheDuration()} is not null, use that + * as the cache TTL. Otherwise, cache indefinitely. + */ + private static Mono cacheResult( + ConnectionContext connectionContext, Mono token) { return connectionContext .getCacheDuration() - .map(cached::cache) - .orElseGet(cached::cache) + .map(token::cache) + .orElseGet(token::cache) .checkpoint(); }