diff --git a/.gitignore b/.gitignore index e3889a0..3f988ad 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ build/ nbbuild/ dist/ nbdist/ -.nb-gradle/ \ No newline at end of file +.nb-gradle/ + +.DS_Store diff --git a/build.gradle b/build.gradle index b50b350..5ce7ce7 100644 --- a/build.gradle +++ b/build.gradle @@ -1,39 +1,41 @@ buildscript { ext { - springBootVersion = '1.5.6.RELEASE' coverallsVersion = '2.8.1' } repositories { mavenCentral() } dependencies { - classpath("org.springframework.boot:spring-boot-gradle-plugin:${springBootVersion}") classpath("org.kt3k.gradle.plugin:coveralls-gradle-plugin:${coverallsVersion}") } } apply plugin: 'java' -apply plugin: 'eclipse' -apply plugin: 'org.springframework.boot' apply plugin: 'jacoco' apply plugin: 'com.github.kt3k.coveralls' +apply plugin: 'maven' version = '0.0.1-SNAPSHOT' +group = 'mertz.security' + sourceCompatibility = 1.8 repositories { mavenCentral() + maven { + url 'https://repo.spring.io/libs-release' + } } dependencies { - compile('org.springframework.boot:spring-boot-starter-data-cassandra') - compile('org.springframework.security.oauth:spring-security-oauth2') + compile('org.springframework.data:spring-data-cassandra:2.0.10.RELEASE') + compile('org.springframework.security.oauth:spring-security-oauth2:2.3.3.RELEASE') compile('com.fasterxml.jackson.core:jackson-databind:2.9.0') compile('com.fasterxml.jackson.core:jackson-annotations:2.9.0') compile('com.fasterxml.jackson.core:jackson-core:2.9.0') - testCompile('org.springframework.boot:spring-boot-starter-test') - testCompile('org.springframework.security:spring-security-test') + testCompile('org.springframework.boot:spring-boot-starter-test:2.0.2.RELEASE') + testCompile('org.springframework.security:spring-security-test:5.0.8.RELEASE') testCompile('org.cassandraunit:cassandra-unit-spring:3.1.3.2') } diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 0000000..28e6f86 --- /dev/null +++ b/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'spring-oauth2-cassandra-token-store' \ No newline at end of file diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStore.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStore.java index e4b9c8a..abb4a2a 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStore.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStore.java @@ -3,18 +3,21 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.cassandra.core.WriteOptions; +import org.springframework.data.cassandra.core.CassandraBatchOperations; import org.springframework.data.cassandra.core.CassandraTemplate; -import org.springframework.data.cassandra.mapping.CassandraMappingContext; +import org.springframework.data.cassandra.core.cql.WriteOptions; +import org.springframework.data.cassandra.core.mapping.CassandraMappingContext; import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.OAuth2RefreshToken; @@ -25,10 +28,6 @@ import org.springframework.stereotype.Component; import com.datastax.driver.core.RegularStatement; -import com.datastax.driver.core.querybuilder.Batch; -import com.datastax.driver.core.querybuilder.Delete; -import com.datastax.driver.core.querybuilder.Insert; -import com.datastax.driver.core.querybuilder.QueryBuilder; import mertz.security.oauth2.provider.token.store.cassandra.cfg.OAuthUtil; import mertz.security.oauth2.provider.token.store.cassandra.model.AccessToken; @@ -53,38 +52,48 @@ public class CassandraTokenStore implements TokenStore { private static final Logger logger = LoggerFactory.getLogger(CassandraTokenStore.class); - @Autowired - private AuthenticationRepository authenticationRepository; + private final AuthenticationRepository authenticationRepository; - @Autowired - private AccessTokenRepository accessTokenRepository; + private final AccessTokenRepository accessTokenRepository; - @Autowired - private RefreshTokenRepository refreshTokenRepository; + private final RefreshTokenRepository refreshTokenRepository; - @Autowired - private RefreshTokenAuthenticationRepository refreshTokenAuthenticationRepository; + private final RefreshTokenAuthenticationRepository refreshTokenAuthenticationRepository; - @Autowired - private AuthenticationToAccessTokenRepository authenticationToAccessTokenRepository; + private final AuthenticationToAccessTokenRepository authenticationToAccessTokenRepository; - @Autowired - private UsernameToAccessTokenRepository usernameToAccessTokenRepository; + private final UsernameToAccessTokenRepository usernameToAccessTokenRepository; - @Autowired - private ClientIdToAccessTokenRepository clientIdToAccessTokenRepository; + private final ClientIdToAccessTokenRepository clientIdToAccessTokenRepository; - @Autowired - private RefreshTokenToAccessTokenRepository refreshTokenToAccessTokenRepository; + private final RefreshTokenToAccessTokenRepository refreshTokenToAccessTokenRepository; - @Autowired - private CassandraTemplate cassandraTemplate; + private final CassandraTemplate cassandraTemplate; - @Autowired - private CassandraMappingContext cassandraMappingContext; + private final AuthenticationKeyGenerator authenticationKeyGenerator; @Autowired - private AuthenticationKeyGenerator authenticationKeyGenerator; + public CassandraTokenStore(AuthenticationRepository authenticationRepository, + AccessTokenRepository accessTokenRepository, + RefreshTokenRepository refreshTokenRepository, + RefreshTokenAuthenticationRepository refreshTokenAuthenticationRepository, + AuthenticationToAccessTokenRepository authenticationToAccessTokenRepository, + UsernameToAccessTokenRepository usernameToAccessTokenRepository, + ClientIdToAccessTokenRepository clientIdToAccessTokenRepository, + RefreshTokenToAccessTokenRepository refreshTokenToAccessTokenRepository, + CassandraTemplate cassandraTemplate, + AuthenticationKeyGenerator authenticationKeyGenerator) { + this.authenticationRepository = authenticationRepository; + this.accessTokenRepository = accessTokenRepository; + this.refreshTokenRepository = refreshTokenRepository; + this.refreshTokenAuthenticationRepository = refreshTokenAuthenticationRepository; + this.authenticationToAccessTokenRepository = authenticationToAccessTokenRepository; + this.usernameToAccessTokenRepository = usernameToAccessTokenRepository; + this.clientIdToAccessTokenRepository = clientIdToAccessTokenRepository; + this.refreshTokenToAccessTokenRepository = refreshTokenToAccessTokenRepository; + this.cassandraTemplate = cassandraTemplate; + this.authenticationKeyGenerator = authenticationKeyGenerator; + } @Override public OAuth2Authentication readAuthentication(OAuth2AccessToken token) { @@ -93,130 +102,94 @@ public OAuth2Authentication readAuthentication(OAuth2AccessToken token) { @Override public OAuth2Authentication readAuthentication(String token) { - Authentication authentication = authenticationRepository.findOne(token); - if (authentication != null) { - ByteBuffer bufferedOAuth2Authentication = authentication.getoAuth2Authentication(); - byte[] serializedOAuth2Authentication = new byte[bufferedOAuth2Authentication.remaining()]; - bufferedOAuth2Authentication.get(serializedOAuth2Authentication); - OAuth2Authentication oAuth2Authentication = SerializationUtils.deserialize(serializedOAuth2Authentication); - return oAuth2Authentication; - } else { - return null; - } + return authenticationRepository + .findById(token) + .map(authentication -> + deserializeOAuth2Authentication(authentication.getoAuth2Authentication())) + .orElse(null); } @Override public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) { - List statementList = new ArrayList(); String jsonAccessToken = OAuthUtil.serializeOAuth2AccessToken(token); byte[] serializedOAuth2Authentication = SerializationUtils.serialize(authentication); ByteBuffer bufferedOAuth2Authentication = ByteBuffer.wrap(serializedOAuth2Authentication); - WriteOptions accessWriteOptions = new WriteOptions(); + WriteOptions.WriteOptionsBuilder accessWriteOptionsBuilder = WriteOptions.builder(); if (token.getExpiration() != null) { int seconds = token.getExpiresIn(); - accessWriteOptions.setTtl(seconds); + accessWriteOptionsBuilder.ttl(seconds); } - - // Insert into AccessToken table - Insert accessInsert = CassandraTemplate.createInsertQuery(AccessToken.TABLE, new AccessToken(token.getValue(), jsonAccessToken), accessWriteOptions, cassandraTemplate.getConverter()); - statementList.add(accessInsert); - - // Insert into Authentication table - Insert authInsert = CassandraTemplate.createInsertQuery(Authentication.TABLE, new Authentication(token.getValue(), bufferedOAuth2Authentication), accessWriteOptions, cassandraTemplate.getConverter()); - statementList.add(authInsert); - - // Insert into AuthenticationToAccessToken table - Insert authToAccessInsert = CassandraTemplate.createInsertQuery(AuthenticationToAccessToken.TABLE, new AuthenticationToAccessToken(authenticationKeyGenerator.extractKey(authentication), jsonAccessToken), accessWriteOptions, cassandraTemplate.getConverter()); - statementList.add(authToAccessInsert); - - // Insert into UsernameToAccessToken table - Insert unameToAccessInsert = CassandraTemplate.createInsertQuery(UsernameToAccessToken.TABLE, new UsernameToAccessToken(OAuthUtil.getApprovalKey(authentication), jsonAccessToken), accessWriteOptions, cassandraTemplate.getConverter()); - statementList.add(unameToAccessInsert); - - // Insert into ClientIdToAccessToken table - Insert clientIdToAccessInsert = CassandraTemplate.createInsertQuery(ClientIdToAccessToken.TABLE, new ClientIdToAccessToken(authentication.getOAuth2Request().getClientId(), jsonAccessToken), accessWriteOptions, cassandraTemplate.getConverter()); - statementList.add(clientIdToAccessInsert); + WriteOptions accessWriteOptions = accessWriteOptionsBuilder.build(); + + CassandraBatchOperations batch = cassandraTemplate.batchOps() + .insert(Collections.singleton(new AccessToken(token.getValue(), jsonAccessToken)), accessWriteOptions) + .insert(Collections.singleton(new Authentication(token.getValue(), bufferedOAuth2Authentication)), + accessWriteOptions) + .insert(Collections.singleton( + new AuthenticationToAccessToken(authenticationKeyGenerator.extractKey(authentication), jsonAccessToken)), + accessWriteOptions) + .insert(Collections.singleton( + new UsernameToAccessToken(OAuthUtil.getApprovalKey(authentication), jsonAccessToken)), + accessWriteOptions) + .insert(Collections.singleton( + new ClientIdToAccessToken(authentication.getOAuth2Request().getClientId(), jsonAccessToken)), + accessWriteOptions); OAuth2RefreshToken oAuth2RefreshToken = token.getRefreshToken(); if (oAuth2RefreshToken != null && oAuth2RefreshToken.getValue() != null) { - WriteOptions refreshWriteOptions = new WriteOptions(); - if (oAuth2RefreshToken instanceof ExpiringOAuth2RefreshToken) { - ExpiringOAuth2RefreshToken expiringRefreshToken = (ExpiringOAuth2RefreshToken) oAuth2RefreshToken; - Date expiration = expiringRefreshToken.getExpiration(); - if (expiration != null) { - int seconds = Long.valueOf((expiration.getTime() - System.currentTimeMillis()) / 1000L).intValue(); - refreshWriteOptions.setTtl(seconds); - } - } - // Insert into RefreshTokenToAccessToken table - Insert refreshTokenToAccessTokenInsert = CassandraTemplate.createInsertQuery(RefreshTokenToAccessToken.TABLE, new RefreshTokenToAccessToken(token.getRefreshToken().getValue(), token.getValue()), refreshWriteOptions, cassandraTemplate.getConverter()); - statementList.add(refreshTokenToAccessTokenInsert); - } + WriteOptions refreshWriteOptions = buildRefreshTokenWriteOptions(oAuth2RefreshToken); - Batch batch = QueryBuilder.batch(statementList.toArray(new RegularStatement[statementList.size()])); - cassandraTemplate.execute(batch); + batch = batch.insert( + Collections.singleton(new RefreshTokenToAccessToken(token.getRefreshToken().getValue(), token.getValue())), + refreshWriteOptions); + } + batch.execute(); } @Override public OAuth2AccessToken readAccessToken(String tokenValue) { - AccessToken accessToken = accessTokenRepository.findOne(tokenValue); - if (accessToken != null) { - return OAuthUtil.deserializeOAuth2AccessToken(accessToken.getoAuth2AccessToken()); - } else { - return null; - } + return accessTokenRepository.findById(tokenValue).map( accessToken -> + OAuthUtil.deserializeOAuth2AccessToken(accessToken.getoAuth2AccessToken()) + ).orElse(null); } @Override public void removeAccessToken(OAuth2AccessToken token) { - List statementList = prepareRemoveAccessTokenStatements(token); - Batch batch = QueryBuilder.batch(statementList.toArray(new RegularStatement[statementList.size()])); - cassandraTemplate.execute(batch); + prepareRemoveAccessTokenStatements(token).execute(); } - private List prepareRemoveAccessTokenStatements(OAuth2AccessToken token) { - //String tokenId = token.getValue(); + private CassandraBatchOperations prepareRemoveAccessTokenStatements(OAuth2AccessToken token) { + CassandraBatchOperations batch = cassandraTemplate.batchOps(); String tokenValue = token.getValue(); String jsonOAuth2AccessToken = OAuthUtil.serializeOAuth2AccessToken(token); - List statementList = new ArrayList(); - - // Delete from AccessToken table - RegularStatement accessTokenDelete = prepareDeleteByPrimaryKeyRegularStatement(AccessToken.class, tokenValue); - statementList.add(accessTokenDelete); + batch.delete(new AccessToken(tokenValue,null)); // Lookup Authentication table for further deleting from AuthenticationToAccessToken table - Authentication authentication = authenticationRepository.findOne(tokenValue); - if (authentication != null) { + authenticationRepository.findById(tokenValue).ifPresent(authentication -> { ByteBuffer bufferedOAuth2Authentication = authentication.getoAuth2Authentication(); byte[] serializedOAuth2Authentication = new byte[bufferedOAuth2Authentication.remaining()]; bufferedOAuth2Authentication.get(serializedOAuth2Authentication); OAuth2Authentication oAuth2Authentication = SerializationUtils.deserialize(serializedOAuth2Authentication); String clientId = oAuth2Authentication.getOAuth2Request().getClientId(); - // Delete from Authentication table - RegularStatement authenticationDelete = prepareDeleteByPrimaryKeyRegularStatement(Authentication.class, tokenValue); - statementList.add(authenticationDelete); - - // Delete from AuthenticationToAccessToken table - RegularStatement authToAccessDelete = prepareDeleteByPrimaryKeyRegularStatement(AuthenticationToAccessToken.class, authenticationKeyGenerator.extractKey(oAuth2Authentication)); - statementList.add(authToAccessDelete); + batch.delete(authentication); + batch.delete(new AuthenticationToAccessToken( + authenticationKeyGenerator.extractKey(oAuth2Authentication),null)); // Delete from UsernameToAccessToken table - Optional optionalUsernameToAccessToken = usernameToAccessTokenRepository.findByKeyAndOAuth2AccessToken(OAuthUtil.getApprovalKey(clientId, oAuth2Authentication.getName()), jsonOAuth2AccessToken); - optionalUsernameToAccessToken.ifPresent(usernameToAccessToken -> { - Delete usernameToAccessDelete = CassandraTemplate.createDeleteQuery(UsernameToAccessToken.TABLE, usernameToAccessToken, null, cassandraTemplate.getConverter()); - statementList.add(usernameToAccessDelete); - }); + usernameToAccessTokenRepository + .findByKeyAndOAuth2AccessToken( + OAuthUtil.getApprovalKey(clientId, oAuth2Authentication.getName()), + jsonOAuth2AccessToken) + .ifPresent(batch::delete); // Delete from ClientIdToAccessToken table - Optional optionalClientIdToAccessToken = clientIdToAccessTokenRepository.findByKeyAndOAuth2AccessToken(clientId, jsonOAuth2AccessToken); - optionalClientIdToAccessToken.ifPresent(clientIdToAccessToken -> { - Delete clientIdToAccessDelete = CassandraTemplate.createDeleteQuery(ClientIdToAccessToken.TABLE, clientIdToAccessToken, null, cassandraTemplate.getConverter()); - statementList.add(clientIdToAccessDelete); - }); - } + clientIdToAccessTokenRepository + .findByKeyAndOAuth2AccessToken(clientId, jsonOAuth2AccessToken) + .ifPresent(batch::delete); + }); - return statementList; + return batch; } @Override @@ -229,139 +202,113 @@ public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authenticat byte[] serializedAuthentication = SerializationUtils.serialize(authentication); ByteBuffer bufferedAuthentication = ByteBuffer.wrap(serializedAuthentication); - WriteOptions refreshWriteOptions = new WriteOptions(); + WriteOptions refreshWriteOptions = buildRefreshTokenWriteOptions(refreshToken); + + cassandraTemplate.batchOps() + .insert(Collections.singleton(new RefreshToken(refreshToken.getValue(), bufferedRefreshToken)), + refreshWriteOptions) + .insert(Collections.singleton(new RefreshTokenAuthentication(refreshToken.getValue(), bufferedAuthentication)), + refreshWriteOptions) + .execute(); + } + + private WriteOptions buildRefreshTokenWriteOptions(OAuth2RefreshToken refreshToken) { + WriteOptions.WriteOptionsBuilder refreshWriteOptionsBuilder = WriteOptions.builder(); if (refreshToken instanceof ExpiringOAuth2RefreshToken) { ExpiringOAuth2RefreshToken expiringRefreshToken = (ExpiringOAuth2RefreshToken) refreshToken; Date expiration = expiringRefreshToken.getExpiration(); if (expiration != null) { int seconds = Long.valueOf((expiration.getTime() - System.currentTimeMillis()) / 1000L).intValue(); - refreshWriteOptions.setTtl(seconds); + refreshWriteOptionsBuilder.ttl(seconds); } } - - // Insert into RefreshToken table - Insert accessInsert = CassandraTemplate.createInsertQuery(RefreshToken.TABLE, new RefreshToken(refreshToken.getValue(), bufferedRefreshToken), refreshWriteOptions, cassandraTemplate.getConverter()); - statementList.add(accessInsert); - - // Insert into RefreshTokenAuthentication table - Insert authInsert = CassandraTemplate.createInsertQuery(RefreshTokenAuthentication.TABLE, new RefreshTokenAuthentication(refreshToken.getValue(), bufferedAuthentication), refreshWriteOptions, cassandraTemplate.getConverter()); - statementList.add(authInsert); - - Batch batch = QueryBuilder.batch(statementList.toArray(new RegularStatement[statementList.size()])); - cassandraTemplate.execute(batch); + return refreshWriteOptionsBuilder.build(); } @Override public OAuth2RefreshToken readRefreshToken(String tokenValue) { - RefreshToken refreshToken = refreshTokenRepository.findOne(tokenValue); - if (refreshToken != null) { + return refreshTokenRepository.findById(tokenValue).map(refreshToken -> { ByteBuffer bufferedRefreshToken = refreshToken.getoAuth2RefreshToken(); byte[] serializedRefreshToken = new byte[bufferedRefreshToken.remaining()]; bufferedRefreshToken.get(serializedRefreshToken); - OAuth2RefreshToken oAuth2RefreshToken = SerializationUtils.deserialize(serializedRefreshToken); - return oAuth2RefreshToken; - } else { - return null; - } + return SerializationUtils.deserialize(serializedRefreshToken); + }).orElse(null); } @Override public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) { - RefreshTokenAuthentication refreshTokenAuthentication = refreshTokenAuthenticationRepository.findOne(token.getValue()); - if (refreshTokenAuthentication != null) { - ByteBuffer bufferedOAuth2Authentication = refreshTokenAuthentication.getoAuth2Authentication(); - byte[] serializedOAuth2Authentication = new byte[bufferedOAuth2Authentication.remaining()]; - bufferedOAuth2Authentication.get(serializedOAuth2Authentication); - OAuth2Authentication oAuth2Authentication = SerializationUtils.deserialize(serializedOAuth2Authentication); - return oAuth2Authentication; - } else { - return null; - } + return refreshTokenAuthenticationRepository.findById(token.getValue()).map(refreshTokenAuthentication -> + deserializeOAuth2Authentication(refreshTokenAuthentication.getoAuth2Authentication())) + .orElse(null); + } + + private OAuth2Authentication deserializeOAuth2Authentication(ByteBuffer byteBuffer) { + byte[] serializedOAuth2Authentication = new byte[byteBuffer.remaining()]; + byteBuffer.get(serializedOAuth2Authentication); + return SerializationUtils.deserialize(serializedOAuth2Authentication); } @Override public void removeRefreshToken(OAuth2RefreshToken token) { String tokenValue = token.getValue(); - List statementList = new ArrayList(); - // Delete from RefreshToken table - statementList.add(prepareDeleteByPrimaryKeyRegularStatement(RefreshToken.class, tokenValue)); - // Delete from RefreshTokenAuthentication table - statementList.add(prepareDeleteByPrimaryKeyRegularStatement(RefreshTokenAuthentication.class, tokenValue)); - // Delete from RefreshTokenToAccessToken table - statementList.add(prepareDeleteByPrimaryKeyRegularStatement(RefreshTokenToAccessToken.class, tokenValue)); - Batch batch = QueryBuilder.batch(statementList.toArray(new RegularStatement[statementList.size()])); - cassandraTemplate.execute(batch); - } - - private RegularStatement prepareDeleteByPrimaryKeyRegularStatement(Class repositoryClass, String primaryKeyValue) { - RegularStatement deleteRegularStatement; - try { - deleteRegularStatement = QueryBuilder.delete().from(repositoryClass.getDeclaredField("TABLE").get(null).toString()).where(QueryBuilder.eq(cassandraMappingContext.getPersistentEntity(repositoryClass).getIdProperty().getColumnName().toCql(), primaryKeyValue)); - } catch (IllegalArgumentException | IllegalAccessException | NoSuchFieldException | SecurityException e) { - logger.error("Error preparing delete statement for repository {}.", repositoryClass.getSimpleName()); - throw new RuntimeException(e); - } - return deleteRegularStatement; + cassandraTemplate.batchOps() + .delete(new RefreshToken(tokenValue,null)) + .delete(new RefreshTokenAuthentication(tokenValue,null)) + .delete(new RefreshTokenToAccessToken(tokenValue, null)) + .execute(); } @Override public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) { String tokenValue = refreshToken.getValue(); // Lookup RefreshTokenToAccessToken table for locating access token - RefreshTokenToAccessToken refreshTokenToAccessToken = refreshTokenToAccessTokenRepository.findOne(tokenValue); - if (refreshTokenToAccessToken != null) { + refreshTokenToAccessTokenRepository.findById(tokenValue).ifPresent(refreshTokenToAccessToken -> { String accessTokenKey = refreshTokenToAccessToken.getAccessTokenKey(); - AccessToken accessToken = accessTokenRepository.findOne(accessTokenKey); - if (accessToken == null) { - // access token removed already or expired. - return; - } - String jsonOAuth2AccessToken = accessToken.getoAuth2AccessToken(); - OAuth2AccessToken oAuth2AccessToken = OAuthUtil.deserializeOAuth2AccessToken(jsonOAuth2AccessToken); - // Delete access token from all related tables - List statementList = prepareRemoveAccessTokenStatements(oAuth2AccessToken); - // Delete from RefreshTokenToAccessToken table - Delete refreshTokenToAccessTokenDelete = CassandraTemplate.createDeleteQuery(RefreshTokenToAccessToken.TABLE, refreshTokenToAccessToken, null, cassandraTemplate.getConverter()); - statementList.add(refreshTokenToAccessTokenDelete); - Batch batch = QueryBuilder.batch(statementList.toArray(new RegularStatement[statementList.size()])); - cassandraTemplate.execute(batch); - } + accessTokenRepository.findById(accessTokenKey).ifPresent(accessToken -> { + + String jsonOAuth2AccessToken = accessToken.getoAuth2AccessToken(); + OAuth2AccessToken oAuth2AccessToken = OAuthUtil.deserializeOAuth2AccessToken(jsonOAuth2AccessToken); + // Delete access token from all related tables + CassandraBatchOperations batch = prepareRemoveAccessTokenStatements(oAuth2AccessToken); + // Delete from RefreshTokenToAccessToken table + batch = batch.delete(refreshTokenToAccessToken); + batch.execute(); + }); + }); } @Override public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) { String key = authenticationKeyGenerator.extractKey(authentication); - AuthenticationToAccessToken authenticationToAccessToken = authenticationToAccessTokenRepository.findOne(key); - if (authenticationToAccessToken != null) { + return authenticationToAccessTokenRepository.findById(key).map(authenticationToAccessToken -> { OAuth2AccessToken oAuth2AccessToken = OAuthUtil.deserializeOAuth2AccessToken(authenticationToAccessToken.getoAuth2AccessToken()); if (oAuth2AccessToken != null && !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(oAuth2AccessToken.getValue())))) { storeAccessToken(oAuth2AccessToken, authentication); } return oAuth2AccessToken; - } else { - return null; - } + }).orElse(null); } @Override public Collection findTokensByClientIdAndUserName(String clientId, String userName) { - String key = OAuthUtil.getApprovalKey(clientId, userName); - Optional> optionalUsernameToAccessTokenSet = usernameToAccessTokenRepository.findByKey(key); - Set oAuth2AccessTokenSet = new HashSet(); - optionalUsernameToAccessTokenSet.ifPresent(usernameToAccessTokenSet -> { - usernameToAccessTokenSet.forEach(usernameToAccessToken -> oAuth2AccessTokenSet.add(OAuthUtil.deserializeOAuth2AccessToken(usernameToAccessToken.getOAuth2AccessToken()))); - }); - return oAuth2AccessTokenSet; + return usernameToAccessTokenRepository.findByKey(OAuthUtil.getApprovalKey(clientId, userName)) + .map(usernameToAccessTokens -> + usernameToAccessTokens.stream() + .map(usernameToAccessToken -> + OAuthUtil.deserializeOAuth2AccessToken(usernameToAccessToken.getOAuth2AccessToken())) + .collect(Collectors.toSet())) + .orElse(Collections.emptySet()); } @Override public Collection findTokensByClientId(String clientId) { - Optional> optionalClientIdToAccessTokenSet = clientIdToAccessTokenRepository.findByKey(clientId); - Set oAuth2AccessTokenSet = new HashSet(); - optionalClientIdToAccessTokenSet.ifPresent(clientIdToAccessTokenSet -> { - clientIdToAccessTokenSet.forEach(clientIdToAccessToken -> oAuth2AccessTokenSet.add(OAuthUtil.deserializeOAuth2AccessToken(clientIdToAccessToken.getOAuth2AccessToken()))); - }); - return oAuth2AccessTokenSet; + return clientIdToAccessTokenRepository.findByKey(clientId) + .map(clientIdToAccessTokens -> + clientIdToAccessTokens.stream() + .map(clientIdToAccessToken -> + OAuthUtil.deserializeOAuth2AccessToken(clientIdToAccessToken.getOAuth2AccessToken())) + .collect(Collectors.toSet())) + .orElse(Collections.emptySet()); } } diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/OAuthUtil.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/OAuthUtil.java index e1b3b9a..c1b70b4 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/OAuthUtil.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/OAuthUtil.java @@ -27,11 +27,6 @@ public AuthenticationKeyGenerator getAuthenticationKeyGenerator() { return new DefaultAuthenticationKeyGenerator(); } - @Bean - public ObjectMapper getObjectMapper() { - return new ObjectMapper(); - } - public static OAuth2AccessToken deserializeOAuth2AccessToken(String jsonOAuth2AccessToken) { try { return OAUTH2ACCESSTOKEN_OBJECT_READER.readValue(jsonOAuth2AccessToken); diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AccessToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AccessToken.java index 61061b5..5adc08a 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AccessToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AccessToken.java @@ -1,7 +1,7 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; @Table(value = AccessToken.TABLE) public class AccessToken { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/Authentication.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/Authentication.java index 3ad744d..1d9085b 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/Authentication.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/Authentication.java @@ -1,9 +1,9 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import java.nio.ByteBuffer; +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; +import java.nio.ByteBuffer; @Table(value = Authentication.TABLE) public class Authentication { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AuthenticationToAccessToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AuthenticationToAccessToken.java index 03dfa4b..a5514c6 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AuthenticationToAccessToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/AuthenticationToAccessToken.java @@ -1,7 +1,8 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; + +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; @Table(value = AuthenticationToAccessToken.TABLE) public class AuthenticationToAccessToken { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/ClientIdToAccessToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/ClientIdToAccessToken.java index 914631d..1114c41 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/ClientIdToAccessToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/ClientIdToAccessToken.java @@ -1,8 +1,8 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import org.springframework.cassandra.core.PrimaryKeyType; -import org.springframework.data.cassandra.mapping.PrimaryKeyColumn; -import org.springframework.data.cassandra.mapping.Table; +import org.springframework.data.cassandra.core.cql.PrimaryKeyType; +import org.springframework.data.cassandra.core.mapping.PrimaryKeyColumn; +import org.springframework.data.cassandra.core.mapping.Table; @Table(value = ClientIdToAccessToken.TABLE) public class ClientIdToAccessToken { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshToken.java index 6575ca0..2d5fc00 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshToken.java @@ -1,9 +1,9 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import java.nio.ByteBuffer; +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; +import java.nio.ByteBuffer; @Table(value = RefreshToken.TABLE) public class RefreshToken { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenAuthentication.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenAuthentication.java index 52c97bd..5982674 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenAuthentication.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenAuthentication.java @@ -1,9 +1,9 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import java.nio.ByteBuffer; +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; +import java.nio.ByteBuffer; @Table(value = RefreshTokenAuthentication.TABLE) public class RefreshTokenAuthentication { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenToAccessToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenToAccessToken.java index 4bd014a..0a3f9bd 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenToAccessToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/RefreshTokenToAccessToken.java @@ -1,7 +1,8 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import org.springframework.data.cassandra.mapping.PrimaryKey; -import org.springframework.data.cassandra.mapping.Table; + +import org.springframework.data.cassandra.core.mapping.PrimaryKey; +import org.springframework.data.cassandra.core.mapping.Table; @Table(value = RefreshTokenToAccessToken.TABLE) public class RefreshTokenToAccessToken { diff --git a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/UsernameToAccessToken.java b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/UsernameToAccessToken.java index e3a6156..244eccd 100644 --- a/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/UsernameToAccessToken.java +++ b/src/main/java/mertz/security/oauth2/provider/token/store/cassandra/model/UsernameToAccessToken.java @@ -1,8 +1,9 @@ package mertz.security.oauth2.provider.token.store.cassandra.model; -import org.springframework.cassandra.core.PrimaryKeyType; -import org.springframework.data.cassandra.mapping.PrimaryKeyColumn; -import org.springframework.data.cassandra.mapping.Table; + +import org.springframework.data.cassandra.core.cql.PrimaryKeyType; +import org.springframework.data.cassandra.core.mapping.PrimaryKeyColumn; +import org.springframework.data.cassandra.core.mapping.Table; @Table(value = UsernameToAccessToken.TABLE) public class UsernameToAccessToken { diff --git a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStoreTests.java b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStoreTests.java index f1a33a1..ae8b078 100644 --- a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStoreTests.java +++ b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/CassandraTokenStoreTests.java @@ -6,15 +6,11 @@ import java.util.Date; import java.util.UUID; -import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.ConfigFileApplicationContextInitializer; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; -import org.springframework.data.cassandra.core.CassandraOperations; -import org.springframework.data.cassandra.mapping.CassandraMappingContext; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.common.DefaultExpiringOAuth2RefreshToken; import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; @@ -26,20 +22,10 @@ import org.springframework.security.oauth2.provider.RequestTokenFactory; import org.springframework.security.oauth2.provider.token.TokenStore; import org.springframework.security.oauth2.provider.token.store.TokenStoreBaseTests; -import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit4.SpringRunner; -@RunWith(SpringRunner.class) @ContextConfiguration(initializers = ConfigFileApplicationContextInitializer.class) -@ActiveProfiles(profiles = "externalcassandra") -public class CassandraTokenStoreTests extends TokenStoreBaseTests { - - @Autowired - private CassandraOperations cassandraOperations; - - @Autowired - private CassandraMappingContext cassandraMappingContext; +public abstract class CassandraTokenStoreTests extends TokenStoreBaseTests { @Autowired private TokenStore cassandraTokenStore; @@ -49,11 +35,6 @@ public TokenStore getTokenStore() { return cassandraTokenStore; } - @Before - public void setUp() throws Exception { - cassandraMappingContext.getTableEntities().forEach(entity -> cassandraOperations.truncate(entity.getTableName())); - } - @Configuration @ComponentScan(basePackages = "mertz.security.oauth2.provider.token.store.cassandra") public static class SpringConfig { diff --git a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/EmbeddedCassandraTokenStoreTests.java b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/EmbeddedCassandraTokenStoreTests.java index 7c30129..3b9c77f 100644 --- a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/EmbeddedCassandraTokenStoreTests.java +++ b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/EmbeddedCassandraTokenStoreTests.java @@ -4,7 +4,9 @@ import org.cassandraunit.spring.CassandraUnitDependencyInjectionTestExecutionListener; import org.cassandraunit.spring.EmbeddedCassandra; import org.junit.runner.RunWith; +import org.springframework.boot.test.context.ConfigFileApplicationContextInitializer; import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.TestExecutionListeners; import org.springframework.test.context.TestExecutionListeners.MergeMode; import org.springframework.test.context.junit4.SpringRunner; @@ -18,5 +20,4 @@ @CassandraDataSet(keyspace = "${spring.data.cassandra.keyspace-name}") @ActiveProfiles(profiles = "embeddedcassandra", inheritProfiles = false) public class EmbeddedCassandraTokenStoreTests extends CassandraTokenStoreTests { - } diff --git a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/CassandraConfig.java b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/CassandraConfig.java index 602df53..4d0fdce 100644 --- a/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/CassandraConfig.java +++ b/src/test/java/mertz/security/oauth2/provider/token/store/cassandra/cfg/CassandraConfig.java @@ -4,10 +4,10 @@ import java.util.List; import org.springframework.beans.factory.annotation.Value; -import org.springframework.cassandra.core.keyspace.CreateKeyspaceSpecification; import org.springframework.context.annotation.Configuration; +import org.springframework.data.cassandra.config.AbstractCassandraConfiguration; import org.springframework.data.cassandra.config.SchemaAction; -import org.springframework.data.cassandra.config.java.AbstractCassandraConfiguration; +import org.springframework.data.cassandra.core.cql.keyspace.CreateKeyspaceSpecification; import org.springframework.data.cassandra.repository.config.EnableCassandraRepositories; @Configuration @@ -28,7 +28,7 @@ protected List getKeyspaceCreations() { @Override public SchemaAction getSchemaAction() { - return SchemaAction.CREATE_IF_NOT_EXISTS; + return SchemaAction.RECREATE; } @Override diff --git a/src/test/java/org/springframework/security/oauth2/provider/token/store/TokenStoreBaseTests.java b/src/test/java/org/springframework/security/oauth2/provider/token/store/TokenStoreBaseTests.java index 26f8e8f..3df793c 100644 --- a/src/test/java/org/springframework/security/oauth2/provider/token/store/TokenStoreBaseTests.java +++ b/src/test/java/org/springframework/security/oauth2/provider/token/store/TokenStoreBaseTests.java @@ -88,7 +88,7 @@ public void testRetrieveAccessToken() { assertEquals(expectedOAuth2AccessToken, actualOAuth2AccessToken); assertEquals(authentication.getUserAuthentication(), getTokenStore().readAuthentication(expectedOAuth2AccessToken.getValue()).getUserAuthentication()); // The authorizationRequest does not match because it is unapproved, but the token was granted to an approved request - assertFalse(storedOAuth2Request.equals(getTokenStore().readAuthentication(expectedOAuth2AccessToken.getValue()).getOAuth2Request())); + assertNotEquals(storedOAuth2Request, getTokenStore().readAuthentication(expectedOAuth2AccessToken.getValue()).getOAuth2Request()); actualOAuth2AccessToken = getTokenStore().getAccessToken(authentication); assertEquals(expectedOAuth2AccessToken, actualOAuth2AccessToken); getTokenStore().removeAccessToken(expectedOAuth2AccessToken); @@ -196,12 +196,13 @@ public void testRemoveRefreshToken() { @Test public void testRemovedTokenCannotBeFoundByUsername() { + String clientId = "id"+UUID.randomUUID(); OAuth2AccessToken token = new DefaultOAuth2AccessToken("testToken"); OAuth2Authentication expectedAuthentication = new OAuth2Authentication(RequestTokenFactory.createOAuth2Request( - "id", false), new TestAuthentication("test2", false)); + clientId, false), new TestAuthentication("test2", false)); getTokenStore().storeAccessToken(token, expectedAuthentication); getTokenStore().removeAccessToken(token); - Collection tokens = getTokenStore().findTokensByClientIdAndUserName("id", "test2"); + Collection tokens = getTokenStore().findTokensByClientIdAndUserName(clientId, "test2"); assertFalse(tokens.contains(token)); assertTrue(tokens.isEmpty()); }