diff --git a/src/main/java/com/amazonaws/secretsmanager/caching/SecretCache.java b/src/main/java/com/amazonaws/secretsmanager/caching/SecretCache.java index d0c2672..e0d5986 100644 --- a/src/main/java/com/amazonaws/secretsmanager/caching/SecretCache.java +++ b/src/main/java/com/amazonaws/secretsmanager/caching/SecretCache.java @@ -128,11 +128,26 @@ private SecretCacheItem getCachedSecret(final String secretId) { * @return The string secret */ public String getSecretString(final String secretId) { + return getSecretString(secretId, null, null); + } + + /** + * Retrieve and cache a secret string from AWS Secrets Manager. + * + * @param secretId the secret ID of the desired secret. + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The secret string for the desired secret. + */ + public String getSecretString(final String secretId, final String versionId, final String versionStage) { SecretCacheItem secret = this.getCachedSecret(secretId); - GetSecretValueResponse gsv = secret.getSecretValue(); - if (null == gsv) { + GetSecretValueResponse gsv = secret.getSecretValue(versionId, versionStage); + + if (gsv == null) { return null; } + return gsv.secretString(); } @@ -143,11 +158,26 @@ public String getSecretString(final String secretId) { * @return The binary secret */ public ByteBuffer getSecretBinary(final String secretId) { + return getSecretBinary(secretId, null, null); + } + + /** + * Retrieve and cache a secret binary from AWS Secrets Manager. + * + * @param secretId the secret ID of the desired secret. + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The secret binary for the desired secret. + */ + public ByteBuffer getSecretBinary(final String secretId, final String versionId, final String versionStage) { SecretCacheItem secret = this.getCachedSecret(secretId); - GetSecretValueResponse gsv = secret.getSecretValue(); - if (null == gsv) { + GetSecretValueResponse gsv = secret.getSecretValue(versionId, versionStage); + + if (gsv == null) { return null; } + return gsv.secretBinary().asByteBuffer(); } diff --git a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheItem.java b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheItem.java index 9c0ed78..74a0331 100644 --- a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheItem.java +++ b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheItem.java @@ -13,6 +13,8 @@ package com.amazonaws.secretsmanager.caching.cache; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; @@ -115,16 +117,32 @@ protected DescribeSecretResponse executeRefresh() { * The result of the Describe Secret request to AWS Secrets Manager. * @return The cached secret version. */ - private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) { + private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse, String versionId, String versionStage) { if (null == describeResponse) { return null; } if (null == describeResponse.versionIdsToStages()) { return null; } - Optional currentVersionId = describeResponse.versionIdsToStages().entrySet() - .stream() - .filter(Objects::nonNull) - .filter(x -> x.getValue() != null) - .filter(x -> x.getValue().contains(this.config.getVersionStage())) - .map(x -> x.getKey()) - .findFirst(); + + Optional currentVersionId = Optional.empty(); + + for (Map.Entry> entry : describeResponse.versionIdsToStages().entrySet()) { + if (entry == null) { + continue; + } + + if (entry.getValue() == null) { + continue; + } + + if (versionId != null && versionId.equals(entry.getKey())) { + currentVersionId = Optional.of(versionId); + break; + } + + if ((versionStage != null && entry.getValue().contains(versionStage)) || entry.getValue().contains(config.getVersionStage())) { + currentVersionId = Optional.of(entry.getKey()); + break; + } + } + if (currentVersionId.isPresent()) { SecretCacheVersion version = versions.get(currentVersionId.get()); if (null == version) { @@ -134,6 +152,7 @@ private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) { } return version; } + return null; } @@ -146,8 +165,26 @@ private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) { */ @Override protected GetSecretValueResponse getSecretValue(DescribeSecretResponse describeResponse) { - SecretCacheVersion version = getVersion(describeResponse); - if (null == version) { return null; } + return getSecretValue(describeResponse, null, null); + } + + /** + * Return the cached GetSecretValue result. + * + * @param describeResponse the DescribeSecret result. + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The cached GetSecretValue result. + */ + @Override + protected GetSecretValueResponse getSecretValue(DescribeSecretResponse describeResponse, String versionId, String versionStage) { + SecretCacheVersion version = getVersion(describeResponse, versionId, versionStage); + + if (version == null) { + return null; + } + return version.getSecretValue(); } diff --git a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheObject.java b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheObject.java index 3d626c4..c7bc995 100644 --- a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheObject.java +++ b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheObject.java @@ -111,6 +111,17 @@ public SecretCacheObject(final String secretId, */ protected abstract GetSecretValueResponse getSecretValue(T result); + /** + * Execute the actual refresh of the cached secret state. + * + * @param result the GetSecretValue or DescribeSecret result. + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The cached GetSecretValue result based on the current cached state. + */ + protected abstract GetSecretValueResponse getSecretValue(T result, String versionId, String versionStage); + public abstract boolean equals(Object obj); public abstract int hashCode(); public abstract String toString(); @@ -236,13 +247,28 @@ public boolean refreshNow() throws InterruptedException { * @return The cached GetSecretValue result. */ public GetSecretValueResponse getSecretValue() { + return getSecretValue(null, null); + } + + /** + * Return the cached GetSecretValue result. + * + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The cached GetSecretValue result. + */ + public GetSecretValueResponse getSecretValue(String versionId, String versionStage) { synchronized (lock) { refresh(); - if (null == this.data) { - if (null != this.exception) { throw this.exception; } + + if (this.data == null) { + if (this.exception != null) { + throw this.exception; + } } - return this.getSecretValue(this.getResult()); + return this.getSecretValue(this.getResult(), versionId, versionStage); } } diff --git a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheVersion.java b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheVersion.java index ec98dbf..20db42b 100644 --- a/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheVersion.java +++ b/src/main/java/com/amazonaws/secretsmanager/caching/cache/SecretCacheVersion.java @@ -98,4 +98,18 @@ protected GetSecretValueResponse getSecretValue(GetSecretValueResponse gsvResult return gsvResult; } + /** + * Return the cached GetSecretValue result. + * + * @param gsvResult the GetSecretValue or DescribeSecret result. + * @param versionId the version ID of the desired secret (optional, can be null). + * @param versionStage the version stage of the desired secret (optional, can be null). + * + * @return The cached GetSecretValue result. + */ + @Override + protected GetSecretValueResponse getSecretValue(GetSecretValueResponse gsvResult, String versionId, String versionStage) { + return gsvResult; + } + } diff --git a/src/test/java/com/amazonaws/secretsmanager/caching/SecretCacheTest.java b/src/test/java/com/amazonaws/secretsmanager/caching/SecretCacheTest.java index 4f4e93b..13f008e 100644 --- a/src/test/java/com/amazonaws/secretsmanager/caching/SecretCacheTest.java +++ b/src/test/java/com/amazonaws/secretsmanager/caching/SecretCacheTest.java @@ -135,6 +135,61 @@ public void basicSecretCacheTest() { sc.close(); } + @Test + public void basicSecretCacheVersionIdTest() { + final String secret = "basicSecretCacheTest"; + Map> versionMap = new HashMap>(); + versionMap.put("versionId", Arrays.asList("AWSCURRENT")); + versionMap.put("otherVersionId", Arrays.asList("AWSCURRENT")); + Mockito.when(describeSecretResponse.versionIdsToStages()).thenReturn(versionMap); + GetSecretValueResponse.Builder resBuilder = GetSecretValueResponse.builder().secretString(secret) + .secretBinary(SdkBytes.fromByteArray(secret.getBytes())); + getSecretValueResponse = resBuilder.build(); + + Mockito.when(asm.describeSecret(Mockito.any(DescribeSecretRequest.class))).thenReturn(describeSecretResponse); + Mockito.when(asm.getSecretValue(Mockito.any(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse); + + SecretCache sc = new SecretCache(asm); + + // Request the secret multiple times and verify the correct result + repeat(10, n -> Assert.assertEquals(sc.getSecretString("", "otherVersionId", null), secret)); + + // Verify that multiple requests did not call the API + Mockito.verify(asm, Mockito.times(1)).describeSecret(Mockito.any(DescribeSecretRequest.class)); + Mockito.verify(asm, Mockito.times(1)).getSecretValue(Mockito.any(GetSecretValueRequest.class)); + + repeat(10, n -> Assert.assertEquals(sc.getSecretBinary("", "otherVersionId", null), + ByteBuffer.wrap(secret.getBytes()))); + sc.close(); + } + + @Test + public void basicSecretCacheVersionStageTest() { + final String secret = "basicSecretCacheTest"; + Map> versionMap = new HashMap>(); + versionMap.put("versionId", Arrays.asList("AWSCURRENT")); + Mockito.when(describeSecretResponse.versionIdsToStages()).thenReturn(versionMap); + GetSecretValueResponse.Builder resBuilder = GetSecretValueResponse.builder().secretString(secret) + .secretBinary(SdkBytes.fromByteArray(secret.getBytes())); + getSecretValueResponse = resBuilder.build(); + + Mockito.when(asm.describeSecret(Mockito.any(DescribeSecretRequest.class))).thenReturn(describeSecretResponse); + Mockito.when(asm.getSecretValue(Mockito.any(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse); + + SecretCache sc = new SecretCache(asm); + + // Request the secret multiple times and verify the correct result + repeat(10, n -> Assert.assertEquals(sc.getSecretString("", null, "AWSCURRENT"), secret)); + + // Verify that multiple requests did not call the API + Mockito.verify(asm, Mockito.times(1)).describeSecret(Mockito.any(DescribeSecretRequest.class)); + Mockito.verify(asm, Mockito.times(1)).getSecretValue(Mockito.any(GetSecretValueRequest.class)); + + repeat(10, n -> Assert.assertEquals(sc.getSecretBinary("", null, "AWSCURRENT"), + ByteBuffer.wrap(secret.getBytes()))); + sc.close(); + } + @Test public void hookSecretCacheTest() { final String secret = "hookSecretCacheTest";