Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,26 @@ private SecretCacheItem getCachedSecret(final String secretId) {
* @return The string secret
*/
public String getSecretString(final String secretId) {
return getSecretString(secretId, null, null);
}

/**
* Retrieve and cache a secret string from AWS Secrets Manager.
*
* @param secretId the secret ID of the desired secret.
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The secret string for the desired secret.
*/
public String getSecretString(final String secretId, final String versionId, final String versionStage) {
SecretCacheItem secret = this.getCachedSecret(secretId);
GetSecretValueResponse gsv = secret.getSecretValue();
if (null == gsv) {
GetSecretValueResponse gsv = secret.getSecretValue(versionId, versionStage);

if (gsv == null) {
return null;
}

return gsv.secretString();
}

Expand All @@ -143,11 +158,26 @@ public String getSecretString(final String secretId) {
* @return The binary secret
*/
public ByteBuffer getSecretBinary(final String secretId) {
return getSecretBinary(secretId, null, null);
}

/**
* Retrieve and cache a secret binary from AWS Secrets Manager.
*
* @param secretId the secret ID of the desired secret.
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The secret binary for the desired secret.
*/
public ByteBuffer getSecretBinary(final String secretId, final String versionId, final String versionStage) {
SecretCacheItem secret = this.getCachedSecret(secretId);
GetSecretValueResponse gsv = secret.getSecretValue();
if (null == gsv) {
GetSecretValueResponse gsv = secret.getSecretValue(versionId, versionStage);

if (gsv == null) {
return null;
}

return gsv.secretBinary().asByteBuffer();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

package com.amazonaws.secretsmanager.caching.cache;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
Expand Down Expand Up @@ -115,16 +117,32 @@ protected DescribeSecretResponse executeRefresh() {
* The result of the Describe Secret request to AWS Secrets Manager.
* @return The cached secret version.
*/
private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) {
private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse, String versionId, String versionStage) {
if (null == describeResponse) { return null; }
if (null == describeResponse.versionIdsToStages()) { return null; }
Optional<String> currentVersionId = describeResponse.versionIdsToStages().entrySet()
.stream()
.filter(Objects::nonNull)
.filter(x -> x.getValue() != null)
.filter(x -> x.getValue().contains(this.config.getVersionStage()))
.map(x -> x.getKey())
.findFirst();

Optional<String> currentVersionId = Optional.empty();

for (Map.Entry<String, List<String>> entry : describeResponse.versionIdsToStages().entrySet()) {
if (entry == null) {
continue;
}

if (entry.getValue() == null) {
continue;
}

if (versionId != null && versionId.equals(entry.getKey())) {
currentVersionId = Optional.of(versionId);
break;
}

if ((versionStage != null && entry.getValue().contains(versionStage)) || entry.getValue().contains(config.getVersionStage())) {
currentVersionId = Optional.of(entry.getKey());
break;
}
}

if (currentVersionId.isPresent()) {
SecretCacheVersion version = versions.get(currentVersionId.get());
if (null == version) {
Expand All @@ -134,6 +152,7 @@ private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) {
}
return version;
}

return null;
}

Expand All @@ -146,8 +165,26 @@ private SecretCacheVersion getVersion(DescribeSecretResponse describeResponse) {
*/
@Override
protected GetSecretValueResponse getSecretValue(DescribeSecretResponse describeResponse) {
SecretCacheVersion version = getVersion(describeResponse);
if (null == version) { return null; }
return getSecretValue(describeResponse, null, null);
}

/**
* Return the cached GetSecretValue result.
*
* @param describeResponse the DescribeSecret result.
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The cached GetSecretValue result.
*/
@Override
protected GetSecretValueResponse getSecretValue(DescribeSecretResponse describeResponse, String versionId, String versionStage) {
SecretCacheVersion version = getVersion(describeResponse, versionId, versionStage);

if (version == null) {
return null;
}

return version.getSecretValue();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ public SecretCacheObject(final String secretId,
*/
protected abstract GetSecretValueResponse getSecretValue(T result);

/**
* Execute the actual refresh of the cached secret state.
*
* @param result the GetSecretValue or DescribeSecret result.
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The cached GetSecretValue result based on the current cached state.
*/
protected abstract GetSecretValueResponse getSecretValue(T result, String versionId, String versionStage);

public abstract boolean equals(Object obj);
public abstract int hashCode();
public abstract String toString();
Expand Down Expand Up @@ -236,13 +247,28 @@ public boolean refreshNow() throws InterruptedException {
* @return The cached GetSecretValue result.
*/
public GetSecretValueResponse getSecretValue() {
return getSecretValue(null, null);
}

/**
* Return the cached GetSecretValue result.
*
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The cached GetSecretValue result.
*/
public GetSecretValueResponse getSecretValue(String versionId, String versionStage) {
synchronized (lock) {
refresh();
if (null == this.data) {
if (null != this.exception) { throw this.exception; }

if (this.data == null) {
if (this.exception != null) {
throw this.exception;
}
}

return this.getSecretValue(this.getResult());
return this.getSecretValue(this.getResult(), versionId, versionStage);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,18 @@ protected GetSecretValueResponse getSecretValue(GetSecretValueResponse gsvResult
return gsvResult;
}

/**
* Return the cached GetSecretValue result.
*
* @param gsvResult the GetSecretValue or DescribeSecret result.
* @param versionId the version ID of the desired secret (optional, can be null).
* @param versionStage the version stage of the desired secret (optional, can be null).
*
* @return The cached GetSecretValue result.
*/
@Override
protected GetSecretValueResponse getSecretValue(GetSecretValueResponse gsvResult, String versionId, String versionStage) {
return gsvResult;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,61 @@ public void basicSecretCacheTest() {
sc.close();
}

@Test
public void basicSecretCacheVersionIdTest() {
final String secret = "basicSecretCacheTest";
Map<String, List<String>> versionMap = new HashMap<String, List<String>>();
versionMap.put("versionId", Arrays.asList("AWSCURRENT"));
versionMap.put("otherVersionId", Arrays.asList("AWSCURRENT"));
Mockito.when(describeSecretResponse.versionIdsToStages()).thenReturn(versionMap);
GetSecretValueResponse.Builder resBuilder = GetSecretValueResponse.builder().secretString(secret)
.secretBinary(SdkBytes.fromByteArray(secret.getBytes()));
getSecretValueResponse = resBuilder.build();

Mockito.when(asm.describeSecret(Mockito.any(DescribeSecretRequest.class))).thenReturn(describeSecretResponse);
Mockito.when(asm.getSecretValue(Mockito.any(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse);

SecretCache sc = new SecretCache(asm);

// Request the secret multiple times and verify the correct result
repeat(10, n -> Assert.assertEquals(sc.getSecretString("", "otherVersionId", null), secret));

// Verify that multiple requests did not call the API
Mockito.verify(asm, Mockito.times(1)).describeSecret(Mockito.any(DescribeSecretRequest.class));
Mockito.verify(asm, Mockito.times(1)).getSecretValue(Mockito.any(GetSecretValueRequest.class));

repeat(10, n -> Assert.assertEquals(sc.getSecretBinary("", "otherVersionId", null),
ByteBuffer.wrap(secret.getBytes())));
sc.close();
}

@Test
public void basicSecretCacheVersionStageTest() {
final String secret = "basicSecretCacheTest";
Map<String, List<String>> versionMap = new HashMap<String, List<String>>();
versionMap.put("versionId", Arrays.asList("AWSCURRENT"));
Mockito.when(describeSecretResponse.versionIdsToStages()).thenReturn(versionMap);
GetSecretValueResponse.Builder resBuilder = GetSecretValueResponse.builder().secretString(secret)
.secretBinary(SdkBytes.fromByteArray(secret.getBytes()));
getSecretValueResponse = resBuilder.build();

Mockito.when(asm.describeSecret(Mockito.any(DescribeSecretRequest.class))).thenReturn(describeSecretResponse);
Mockito.when(asm.getSecretValue(Mockito.any(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse);

SecretCache sc = new SecretCache(asm);

// Request the secret multiple times and verify the correct result
repeat(10, n -> Assert.assertEquals(sc.getSecretString("", null, "AWSCURRENT"), secret));

// Verify that multiple requests did not call the API
Mockito.verify(asm, Mockito.times(1)).describeSecret(Mockito.any(DescribeSecretRequest.class));
Mockito.verify(asm, Mockito.times(1)).getSecretValue(Mockito.any(GetSecretValueRequest.class));

repeat(10, n -> Assert.assertEquals(sc.getSecretBinary("", null, "AWSCURRENT"),
ByteBuffer.wrap(secret.getBytes())));
sc.close();
}

@Test
public void hookSecretCacheTest() {
final String secret = "hookSecretCacheTest";
Expand Down