2323import jakarta .annotation .Nonnull ;
2424import java .net .URI ;
2525import java .util .HashMap ;
26+ import java .util .List ;
2627import java .util .Map ;
2728import java .util .Optional ;
2829import java .util .Set ;
3334import org .apache .polaris .core .storage .StorageAccessProperty ;
3435import org .apache .polaris .core .storage .StorageUtil ;
3536import org .apache .polaris .core .storage .aws .StsClientProvider .StsDestination ;
37+ import org .slf4j .Logger ;
38+ import org .slf4j .LoggerFactory ;
3639import software .amazon .awssdk .auth .credentials .AwsCredentialsProvider ;
3740import software .amazon .awssdk .policybuilder .iam .IamConditionOperator ;
3841import software .amazon .awssdk .policybuilder .iam .IamEffect ;
@@ -49,6 +52,9 @@ public class AwsCredentialsStorageIntegration
4952 private final StsClientProvider stsClientProvider ;
5053 private final Optional <AwsCredentialsProvider > credentialsProvider ;
5154
55+ private static final Logger LOGGER =
56+ LoggerFactory .getLogger (AwsCredentialsStorageIntegration .class );
57+
5258 public AwsCredentialsStorageIntegration (
5359 AwsStorageConfigurationInfo config , StsClient fixedClient ) {
5460 this (config , (destination ) -> fixedClient );
@@ -80,6 +86,7 @@ public StorageAccessConfig getSubscopedCreds(
8086 realmConfig .getConfig (STORAGE_CREDENTIAL_DURATION_SECONDS );
8187 AwsStorageConfigurationInfo storageConfig = config ();
8288 String region = storageConfig .getRegion ();
89+ String accountId = storageConfig .getAwsAccountId ();
8390 StorageAccessConfig .Builder accessConfig = StorageAccessConfig .builder ();
8491
8592 if (shouldUseSts (storageConfig )) {
@@ -90,10 +97,12 @@ public StorageAccessConfig getSubscopedCreds(
9097 .roleSessionName ("PolarisAwsCredentialsStorageIntegration" )
9198 .policy (
9299 policyString (
93- storageConfig . getAwsPartition () ,
100+ storageConfig ,
94101 allowListOperation ,
95102 allowedReadLocations ,
96- allowedWriteLocations )
103+ allowedWriteLocations ,
104+ region ,
105+ accountId )
97106 .toJson ())
98107 .durationSeconds (storageCredentialDurationSeconds );
99108 credentialsProvider .ifPresent (
@@ -163,12 +172,13 @@ private boolean shouldUseSts(AwsStorageConfigurationInfo storageConfig) {
163172 * ListBucket privileges with no resources. This prevents us from sending an empty policy to AWS
164173 * and just assuming the role with full privileges.
165174 */
166- // TODO - add KMS key access
167175 private IamPolicy policyString (
168- String awsPartition ,
176+ AwsStorageConfigurationInfo storageConfigurationInfo ,
169177 boolean allowList ,
170178 Set <String > readLocations ,
171- Set <String > writeLocations ) {
179+ Set <String > writeLocations ,
180+ String region ,
181+ String accountId ) {
172182 IamPolicy .Builder policyBuilder = IamPolicy .builder ();
173183 IamStatement .Builder allowGetObjectStatementBuilder =
174184 IamStatement .builder ()
@@ -178,7 +188,9 @@ private IamPolicy policyString(
178188 Map <String , IamStatement .Builder > bucketListStatementBuilder = new HashMap <>();
179189 Map <String , IamStatement .Builder > bucketGetLocationStatementBuilder = new HashMap <>();
180190
181- String arnPrefix = arnPrefixForPartition (awsPartition );
191+ String arnPrefix = arnPrefixForPartition (storageConfigurationInfo .getAwsPartition ());
192+ String currentKmsKey = storageConfigurationInfo .currentKmsKey ();
193+ List <String > allowedKmsKeys = storageConfigurationInfo .allowedKmsKeys ();
182194 Stream .concat (readLocations .stream (), writeLocations .stream ())
183195 .distinct ()
184196 .forEach (
@@ -225,6 +237,9 @@ private IamPolicy policyString(
225237 arnPrefix + StorageUtil .concatFilePrefixes (parseS3Path (uri ), "*" , "/" )));
226238 });
227239 policyBuilder .addStatement (allowPutObjectStatementBuilder .build ());
240+ addKmsKeyPolicy (currentKmsKey , allowedKmsKeys , policyBuilder , true , region , accountId );
241+ } else {
242+ addKmsKeyPolicy (currentKmsKey , allowedKmsKeys , policyBuilder , false , region , accountId );
228243 }
229244 if (!bucketListStatementBuilder .isEmpty ()) {
230245 bucketListStatementBuilder
@@ -239,7 +254,86 @@ private IamPolicy policyString(
239254 bucketGetLocationStatementBuilder
240255 .values ()
241256 .forEach (statementBuilder -> policyBuilder .addStatement (statementBuilder .build ()));
242- return policyBuilder .addStatement (allowGetObjectStatementBuilder .build ()).build ();
257+ var r = policyBuilder .addStatement (allowGetObjectStatementBuilder .build ()).build ();
258+ LOGGER .info ("Policies {}" , r );
259+ return r ;
260+ }
261+
262+ private static void addKmsKeyPolicy (
263+ String kmsKeyArn ,
264+ List <String > allowedKmsKeys ,
265+ IamPolicy .Builder policyBuilder ,
266+ boolean canEncrypt ,
267+ String region ,
268+ String accountId ) {
269+
270+ IamStatement .Builder allowKms = buildBaseKmsStatement (canEncrypt );
271+ boolean hasCurrentKey = kmsKeyArn != null ;
272+ boolean hasAllowedKeys = hasAllowedKmsKeys (allowedKmsKeys );
273+
274+ if (hasCurrentKey ) {
275+ addKmsKeyResource (kmsKeyArn , allowKms );
276+ }
277+
278+ if (hasAllowedKeys ) {
279+ addAllowedKmsKeyResources (allowedKmsKeys , allowKms );
280+ }
281+
282+ // Add KMS statement if we have any KMS key configuration
283+ if (hasCurrentKey || hasAllowedKeys ) {
284+ policyBuilder .addStatement (allowKms .build ());
285+ } else if (!canEncrypt ) {
286+ // Only add wildcard KMS access for read-only operations when no specific keys are configured
287+ addAllKeysResource (region , accountId , allowKms );
288+ policyBuilder .addStatement (allowKms .build ());
289+ }
290+ }
291+
292+ private static IamStatement .Builder buildBaseKmsStatement (boolean canEncrypt ) {
293+ IamStatement .Builder allowKms =
294+ IamStatement .builder ()
295+ .effect (IamEffect .ALLOW )
296+ .addAction ("kms:GenerateDataKeyWithoutPlaintext" )
297+ .addAction ("kms:DescribeKey" )
298+ .addAction ("kms:Decrypt" )
299+ .addAction ("kms:GenerateDataKey" );
300+
301+ if (canEncrypt ) {
302+ allowKms .addAction ("kms:Encrypt" );
303+ }
304+
305+ return allowKms ;
306+ }
307+
308+ private static void addKmsKeyResource (String kmsKeyArn , IamStatement .Builder allowKms ) {
309+ if (kmsKeyArn != null ) {
310+ LOGGER .info ("Adding KMS key policy for key {}" , kmsKeyArn );
311+ allowKms .addResource (IamResource .create (kmsKeyArn ));
312+ }
313+ }
314+
315+ private static boolean hasAllowedKmsKeys (List <String > allowedKmsKeys ) {
316+ return allowedKmsKeys != null && !allowedKmsKeys .isEmpty ();
317+ }
318+
319+ private static void addAllowedKmsKeyResources (
320+ List <String > allowedKmsKeys , IamStatement .Builder allowKms ) {
321+ allowedKmsKeys .forEach (
322+ keyArn -> {
323+ LOGGER .info ("Adding allowed KMS key policy for key {}" , keyArn );
324+ allowKms .addResource (IamResource .create (keyArn ));
325+ });
326+ }
327+
328+ private static void addAllKeysResource (
329+ String region , String accountId , IamStatement .Builder allowKms ) {
330+ String allKeysArn = arnKeyAll (region , accountId );
331+ allowKms .addResource (IamResource .create (allKeysArn ));
332+ LOGGER .info ("Adding KMS key policy for all keys in account {}" , accountId );
333+ }
334+
335+ private static String arnKeyAll (String region , String accountId ) {
336+ return String .format ("arn:aws:kms:%s:%s:key/*" , region , accountId );
243337 }
244338
245339 private static String arnPrefixForPartition (String awsPartition ) {
0 commit comments