2323import jakarta .annotation .Nonnull ;
2424import java .net .URI ;
2525import java .util .HashMap ;
26+ import java .util .List ;
2627import java .util .Map ;
2728import java .util .Optional ;
2829import java .util .Set ;
3334import org .apache .polaris .core .storage .StorageAccessProperty ;
3435import org .apache .polaris .core .storage .StorageUtil ;
3536import org .apache .polaris .core .storage .aws .StsClientProvider .StsDestination ;
37+ import org .slf4j .Logger ;
38+ import org .slf4j .LoggerFactory ;
3639import software .amazon .awssdk .auth .credentials .AwsCredentialsProvider ;
3740import software .amazon .awssdk .policybuilder .iam .IamConditionOperator ;
3841import software .amazon .awssdk .policybuilder .iam .IamEffect ;
@@ -49,6 +52,9 @@ public class AwsCredentialsStorageIntegration
4952 private final StsClientProvider stsClientProvider ;
5053 private final Optional <AwsCredentialsProvider > credentialsProvider ;
5154
55+ private static final Logger LOGGER =
56+ LoggerFactory .getLogger (AwsCredentialsStorageIntegration .class );
57+
5258 public AwsCredentialsStorageIntegration (
5359 AwsStorageConfigurationInfo config , StsClient fixedClient ) {
5460 this (config , (destination ) -> fixedClient );
@@ -80,6 +86,7 @@ public StorageAccessConfig getSubscopedCreds(
8086 realmConfig .getConfig (STORAGE_CREDENTIAL_DURATION_SECONDS );
8187 AwsStorageConfigurationInfo storageConfig = config ();
8288 String region = storageConfig .getRegion ();
89+ String accountId = storageConfig .getAwsAccountId ();
8390 StorageAccessConfig .Builder accessConfig = StorageAccessConfig .builder ();
8491
8592 if (shouldUseSts (storageConfig )) {
@@ -90,10 +97,12 @@ public StorageAccessConfig getSubscopedCreds(
9097 .roleSessionName ("PolarisAwsCredentialsStorageIntegration" )
9198 .policy (
9299 policyString (
93- storageConfig . getAwsPartition () ,
100+ storageConfig ,
94101 allowListOperation ,
95102 allowedReadLocations ,
96- allowedWriteLocations )
103+ allowedWriteLocations ,
104+ region ,
105+ accountId )
97106 .toJson ())
98107 .durationSeconds (storageCredentialDurationSeconds );
99108 credentialsProvider .ifPresent (
@@ -163,12 +172,13 @@ private boolean shouldUseSts(AwsStorageConfigurationInfo storageConfig) {
163172 * ListBucket privileges with no resources. This prevents us from sending an empty policy to AWS
164173 * and just assuming the role with full privileges.
165174 */
166- // TODO - add KMS key access
167175 private IamPolicy policyString (
168- String awsPartition ,
176+ AwsStorageConfigurationInfo storageConfigurationInfo ,
169177 boolean allowList ,
170178 Set <String > readLocations ,
171- Set <String > writeLocations ) {
179+ Set <String > writeLocations ,
180+ String region ,
181+ String accountId ) {
172182 IamPolicy .Builder policyBuilder = IamPolicy .builder ();
173183 IamStatement .Builder allowGetObjectStatementBuilder =
174184 IamStatement .builder ()
@@ -178,7 +188,9 @@ private IamPolicy policyString(
178188 Map <String , IamStatement .Builder > bucketListStatementBuilder = new HashMap <>();
179189 Map <String , IamStatement .Builder > bucketGetLocationStatementBuilder = new HashMap <>();
180190
181- String arnPrefix = arnPrefixForPartition (awsPartition );
191+ String arnPrefix = arnPrefixForPartition (storageConfigurationInfo .getAwsPartition ());
192+ String currentKmsKey = storageConfigurationInfo .getCurrentKmsKey ();
193+ List <String > allowedKmsKeys = storageConfigurationInfo .getAllowedKmsKeys ();
182194 Stream .concat (readLocations .stream (), writeLocations .stream ())
183195 .distinct ()
184196 .forEach (
@@ -225,6 +237,9 @@ private IamPolicy policyString(
225237 arnPrefix + StorageUtil .concatFilePrefixes (parseS3Path (uri ), "*" , "/" )));
226238 });
227239 policyBuilder .addStatement (allowPutObjectStatementBuilder .build ());
240+ addKmsKeyPolicy (currentKmsKey , allowedKmsKeys , policyBuilder , true , region , accountId );
241+ } else {
242+ addKmsKeyPolicy (currentKmsKey , allowedKmsKeys , policyBuilder , false , region , accountId );
228243 }
229244 if (!bucketListStatementBuilder .isEmpty ()) {
230245 bucketListStatementBuilder
@@ -242,6 +257,86 @@ private IamPolicy policyString(
242257 return policyBuilder .addStatement (allowGetObjectStatementBuilder .build ()).build ();
243258 }
244259
260+ private static void addKmsKeyPolicy (
261+ String kmsKeyArn ,
262+ List <String > allowedKmsKeys ,
263+ IamPolicy .Builder policyBuilder ,
264+ boolean canWrite ,
265+ String region ,
266+ String accountId ) {
267+
268+ IamStatement .Builder allowKms = buildBaseKmsStatement (canWrite );
269+ boolean hasCurrentKey = kmsKeyArn != null ;
270+ boolean hasAllowedKeys = hasAllowedKmsKeys (allowedKmsKeys );
271+
272+ if (hasCurrentKey ) {
273+ addKmsKeyResource (kmsKeyArn , allowKms );
274+ }
275+
276+ if (hasAllowedKeys ) {
277+ addAllowedKmsKeyResources (allowedKmsKeys , allowKms );
278+ }
279+
280+ // Add KMS statement if we have any KMS key configuration
281+ if (hasCurrentKey || hasAllowedKeys ) {
282+ policyBuilder .addStatement (allowKms .build ());
283+ } else if (!canWrite ) {
284+ // Only add wildcard KMS access for read-only operations when no specific keys are configured
285+ // this check is for minio because it doesn't have region or account id
286+ if (region != null && accountId != null ) {
287+ addAllKeysResource (region , accountId , allowKms );
288+ policyBuilder .addStatement (allowKms .build ());
289+ }
290+ }
291+ }
292+
293+ private static IamStatement .Builder buildBaseKmsStatement (boolean canEncrypt ) {
294+ IamStatement .Builder allowKms =
295+ IamStatement .builder ()
296+ .effect (IamEffect .ALLOW )
297+ .addAction ("kms:GenerateDataKeyWithoutPlaintext" )
298+ .addAction ("kms:DescribeKey" )
299+ .addAction ("kms:Decrypt" )
300+ .addAction ("kms:GenerateDataKey" );
301+
302+ if (canEncrypt ) {
303+ allowKms .addAction ("kms:Encrypt" );
304+ }
305+
306+ return allowKms ;
307+ }
308+
309+ private static void addKmsKeyResource (String kmsKeyArn , IamStatement .Builder allowKms ) {
310+ if (kmsKeyArn != null ) {
311+ LOGGER .debug ("Adding KMS key policy for key {}" , kmsKeyArn );
312+ allowKms .addResource (IamResource .create (kmsKeyArn ));
313+ }
314+ }
315+
316+ private static boolean hasAllowedKmsKeys (List <String > allowedKmsKeys ) {
317+ return allowedKmsKeys != null && !allowedKmsKeys .isEmpty ();
318+ }
319+
320+ private static void addAllowedKmsKeyResources (
321+ List <String > allowedKmsKeys , IamStatement .Builder allowKms ) {
322+ allowedKmsKeys .forEach (
323+ keyArn -> {
324+ LOGGER .debug ("Adding allowed KMS key policy for key {}" , keyArn );
325+ allowKms .addResource (IamResource .create (keyArn ));
326+ });
327+ }
328+
329+ private static void addAllKeysResource (
330+ String region , String accountId , IamStatement .Builder allowKms ) {
331+ String allKeysArn = arnKeyAll (region , accountId );
332+ allowKms .addResource (IamResource .create (allKeysArn ));
333+ LOGGER .debug ("Adding KMS key policy for all keys in account {}" , accountId );
334+ }
335+
336+ private static String arnKeyAll (String region , String accountId ) {
337+ return String .format ("arn:aws:kms:%s:%s:key/*" , region , accountId );
338+ }
339+
245340 private static String arnPrefixForPartition (String awsPartition ) {
246341 return String .format ("arn:%s:s3:::" , awsPartition != null ? awsPartition : "aws" );
247342 }
0 commit comments