diff --git a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java index d7d7e071f..62ff53d45 100644 --- a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java +++ b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java @@ -13,7 +13,6 @@ package com.amazonaws.crypto.examples.datakeycaching; -import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.AwsCryptoResult; @@ -22,8 +21,8 @@ import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; import com.amazonaws.encryptionsdk.keyrings.Keyring; import com.amazonaws.encryptionsdk.keyrings.StandardKeyrings; -import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import com.amazonaws.regions.Region; import com.amazonaws.services.kinesis.AmazonKinesis; import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; @@ -32,7 +31,6 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -73,10 +71,10 @@ public MultiRegionRecordPusherExample(final Region[] regions, final String kmsAl .build()); keyrings.add(StandardKeyrings.awsKmsBuilder() - .awsKmsClientSupplier(AwsKmsClientSupplier.builder() - .credentialsProvider(credentialsProvider) - .allowedRegions(Collections.singleton(region.getName())) - .build()) + .awsKmsClientSupplier(StandardAwsKmsClientSuppliers + .allowRegionsBuilder(Collections.singleton(region.getName())) + .baseClientSupplier(StandardAwsKmsClientSuppliers.defaultBuilder() + .credentialsProvider(credentialsProvider).build()).build()) .generatorKeyId(AwsKmsCmkId.fromString(kmsAliasName)).build()); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java index 483a13506..5a67334e8 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java @@ -16,6 +16,7 @@ import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import java.util.List; @@ -111,7 +112,7 @@ public AwsKmsKeyringBuilder generatorKeyId(AwsKmsCmkId generatorKeyId) { */ public Keyring build() { if (awsKmsClientSupplier == null) { - awsKmsClientSupplier = AwsKmsClientSupplier.builder().build(); + awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); } return new AwsKmsKeyring(DataKeyEncryptionDao.awsKms(awsKmsClientSupplier, grantTokens), diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java index 9cf37ea12..1a3fc16f6 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java @@ -13,8 +13,8 @@ package com.amazonaws.encryptionsdk.keyrings; -import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import java.util.Arrays; import java.util.List; @@ -80,15 +80,14 @@ public static AwsKmsKeyringBuilder awsKmsBuilder() { * AWS KMS Discovery keyrings do not specify any CMKs to decrypt with, and thus will attempt to decrypt * using any encrypted data key in an encrypted message. AWS KMS Discovery keyrings do not perform encryption. *

- * To create an AWS KMS Regional Discovery Keyring, construct an {@link AwsKmsClientSupplier} using - * {@link AwsKmsClientSupplier#builder()} to specify which regions to include/exclude. + * To create an AWS KMS Regional Discovery Keyring, use {@link StandardAwsKmsClientSuppliers#allowRegionsBuilder} or + * {@link StandardAwsKmsClientSuppliers#denyRegionsBuilder} to specify which regions to include/exclude. *

* For example, to include only CMKs in the us-east-1 region: *
      * StandardKeyrings.awsKmsDiscovery()
      *             .awsKmsClientSupplier(
-     *                     AwsKmsClientSupplier.builder()
-     *                     .allowedRegions(Collections.singleton("us-east-1")).build())
+     *                     StandardAwsKmsClientSuppliers.allowRegionsBuilder(Collections.singleton("us-east-1")).build()
      *             .build();
      * 
* diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java index ceccb5936..92a0bfee1 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java @@ -13,26 +13,13 @@ package com.amazonaws.encryptionsdk.kms; -import com.amazonaws.ClientConfiguration; import com.amazonaws.arn.Arn; -import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.AWSKMSClientBuilder; -import com.amazonaws.services.kms.model.AWSKMSException; import javax.annotation.Nullable; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Proxy; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; import static java.util.Objects.requireNonNull; -import static org.apache.commons.lang3.Validate.isTrue; -import static org.apache.commons.lang3.Validate.notEmpty; /** * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The @@ -51,15 +38,6 @@ public interface AwsKmsClientSupplier { */ AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException; - /** - * Gets a Builder for constructing an AwsKmsClientSupplier - * - * @return The builder - */ - static Builder builder() { - return new Builder(AWSKMSClientBuilder.standard()); - } - /** * Parses region from the given key id (if possible) and passes that region to the * given clientSupplier to produce an {@code AWSKMS} client. @@ -78,156 +56,4 @@ static AWSKMS getClientByKeyId(AwsKmsCmkId keyId, AwsKmsClientSupplier clientSup return clientSupplier.getClient(null); } - - /** - * Builder to construct an AwsKmsClientSupplier given various - * optional settings. - */ - class Builder { - - private AWSCredentialsProvider credentialsProvider; - private ClientConfiguration clientConfiguration; - private Set allowedRegions = Collections.emptySet(); - private Set excludedRegions = Collections.emptySet(); - private boolean clientCachingEnabled = true; - private final Map clientsCache = new HashMap<>(); - private static final Set AWSKMS_METHODS = new HashSet<>(); - private AWSKMSClientBuilder awsKmsClientBuilder; - - static { - AWSKMS_METHODS.add("generateDataKey"); - AWSKMS_METHODS.add("encrypt"); - AWSKMS_METHODS.add("decrypt"); - } - - Builder(AWSKMSClientBuilder awsKmsClientBuilder) { - this.awsKmsClientBuilder = awsKmsClientBuilder; - } - - public AwsKmsClientSupplier build() { - isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(), - "Either allowed regions or excluded regions may be set, not both."); - - return regionId -> { - if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s", - regionId, allowedRegions)); - } - - if (excludedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s", - regionId, excludedRegions)); - } - - if (clientsCache.containsKey(regionId)) { - return clientsCache.get(regionId); - } - - if (credentialsProvider != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider); - } - - if (clientConfiguration != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration); - } - - if (regionId != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); - } - - AWSKMS client = awsKmsClientBuilder.build(); - - if (clientCachingEnabled) { - client = newCachingProxy(client, regionId); - } - - return client; - }; - } - - /** - * Sets the AWSCredentialsProvider used by the client. - * - * @param credentialsProvider New AWSCredentialsProvider to use. - */ - public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) { - this.credentialsProvider = credentialsProvider; - return this; - } - - /** - * Sets the ClientConfiguration to be used by the client. - * - * @param clientConfiguration Custom configuration to use. - */ - public Builder clientConfiguration(ClientConfiguration clientConfiguration) { - this.clientConfiguration = clientConfiguration; - return this; - } - - /** - * Sets the AWS regions that the client supplier is permitted to use. - * - * @param regions The set of regions. - */ - public Builder allowedRegions(Set regions) { - notEmpty(regions, "At least one region is required"); - this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); - return this; - } - - /** - * Sets the AWS regions that the client supplier is not permitted to use. - * - * @param regions The set of regions. - */ - public Builder excludedRegions(Set regions) { - requireNonNull(regions, "regions is required"); - this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); - return this; - } - - /** - * When set to false, disables the AWSKMS client for each region from being cached and reused. - * By default, client caching is enabled. - * - * @param enabled Whether or not caching is enabled. - */ - public Builder clientCaching(boolean enabled) { - this.clientCachingEnabled = enabled; - return this; - } - - /** - * Creates a proxy for the AWSKMS client that will populate the client into the client cache - * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a - * a malicious user from causing a local resource DOS by sending ciphertext with a large number - * of spurious regions, thereby filling the cache with regions and exhausting resources. - * - * @param client The client to proxy - * @param regionId The region the client is associated with - * @return The proxy - */ - private AWSKMS newCachingProxy(AWSKMS client, String regionId) { - return (AWSKMS) Proxy.newProxyInstance( - AWSKMS.class.getClassLoader(), - new Class[]{AWSKMS.class}, - (proxy, method, methodArgs) -> { - try { - final Object result = method.invoke(client, methodArgs); - if (AWSKMS_METHODS.contains(method.getName())) { - clientsCache.put(regionId, client); - } - return result; - } catch (InvocationTargetException e) { - if (e.getTargetException() instanceof AWSKMSException && - AWSKMS_METHODS.contains(method.getName())) { - clientsCache.put(regionId, client); - } - - throw e.getTargetException(); - } - }); - } - } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java new file mode 100644 index 000000000..24b71c45c --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java @@ -0,0 +1,257 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import com.amazonaws.services.kms.model.AWSKMSException; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Proxy; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.notEmpty; + +/** + * Factory methods for instantiating the standard {@code AwsKmsClientSupplier}s provided by the AWS Encryption SDK. + */ +public class StandardAwsKmsClientSuppliers { + + /** + * A builder to construct the default AwsKmsClientSupplier that will create and cache clients + * for any region. Credentials and client configuration may be specified if necessary. + * + * @return The builder + */ + public static DefaultAwsKmsClientSupplierBuilder defaultBuilder() { + return new DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder.standard()); + } + + /** + * A builder to construct an AwsKmsClientSupplier that will + * only supply clients for a given set of AWS regions. + * + * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for + * @return The builder + */ + public static AllowRegionsAwsKmsClientSupplierBuilder allowRegionsBuilder(Set allowedRegions) { + return new AllowRegionsAwsKmsClientSupplierBuilder(allowedRegions); + } + + /** + * A builder to construct an AwsKmsClientSupplier that will + * supply clients for all AWS regions except the given set of regions. + * + * @param deniedRegions the AWS regions that the client supplier will not supply clients for + * @return The builder + */ + public static DenyRegionsAwsKmsClientSupplierBuilder denyRegionsBuilder(Set deniedRegions) { + return new DenyRegionsAwsKmsClientSupplierBuilder(deniedRegions); + } + + + /** + * Builder to construct an AwsKmsClientSupplier that will create and cache clients + * for any region. CredentialProvider and ClientConfiguration are optional and may + * be configured if necessary. + */ + public static class DefaultAwsKmsClientSupplierBuilder { + + private AWSCredentialsProvider credentialsProvider; + private ClientConfiguration clientConfiguration; + private final Map clientsCache = new ConcurrentHashMap<>(); + private static final Set AWSKMS_METHODS = new HashSet<>(); + private AWSKMSClientBuilder awsKmsClientBuilder; + private static final String NULL_REGION = "null-region"; + + static { + AWSKMS_METHODS.add("generateDataKey"); + AWSKMS_METHODS.add("encrypt"); + AWSKMS_METHODS.add("decrypt"); + } + + DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder awsKmsClientBuilder) { + this.awsKmsClientBuilder = awsKmsClientBuilder; + } + + public AwsKmsClientSupplier build() { + + return regionId -> { + + if(regionId == null) { + regionId = NULL_REGION; + } + + if (clientsCache.containsKey(regionId)) { + return clientsCache.get(regionId); + } + + if (credentialsProvider != null) { + awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider); + } + + if (clientConfiguration != null) { + awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration); + } + + if (!regionId.equals(NULL_REGION)) { + awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); + } + + return newCachingProxy(awsKmsClientBuilder.build(), regionId); + }; + } + + /** + * Sets the AWSCredentialsProvider used by the client. + * + * @param credentialsProvider New AWSCredentialsProvider to use. + */ + public DefaultAwsKmsClientSupplierBuilder credentialsProvider(AWSCredentialsProvider credentialsProvider) { + this.credentialsProvider = credentialsProvider; + return this; + } + + /** + * Sets the ClientConfiguration to be used by the client. + * + * @param clientConfiguration Custom configuration to use. + */ + public DefaultAwsKmsClientSupplierBuilder clientConfiguration(ClientConfiguration clientConfiguration) { + this.clientConfiguration = clientConfiguration; + return this; + } + + /** + * Creates a proxy for the AWSKMS client that will populate the client into the client cache + * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a + * a malicious user from causing a local resource DOS by sending ciphertext with a large number + * of spurious regions, thereby filling the cache with regions and exhausting resources. + * + * @param client The client to proxy + * @param regionId The region the client is associated with + * @return The proxy + */ + private AWSKMS newCachingProxy(AWSKMS client, String regionId) { + return (AWSKMS) Proxy.newProxyInstance( + AWSKMS.class.getClassLoader(), + new Class[]{AWSKMS.class}, + (proxy, method, methodArgs) -> { + try { + final Object result = method.invoke(client, methodArgs); + if (AWSKMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + return result; + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof AWSKMSException && + AWSKMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + + throw e.getTargetException(); + } + }); + } + } + + /** + * An AwsKmsClientSupplier that will only supply clients for a given set of AWS regions. + */ + public static class AllowRegionsAwsKmsClientSupplierBuilder { + + private final Set allowedRegions; + private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); + + private AllowRegionsAwsKmsClientSupplierBuilder(Set allowedRegions) { + notEmpty(allowedRegions, "At least one region is required"); + requireNonNull(baseClientSupplier, "baseClientSupplier is required"); + + this.allowedRegions = allowedRegions; + } + + /** + * Constructs the AwsKmsClientSupplier. + * + * @return The AwsKmsClientSupplier + */ + public AwsKmsClientSupplier build() { + return regionId -> { + + if (!allowedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is not in the set of allowed regions %s", + regionId, allowedRegions)); + } + + return baseClientSupplier.getClient(regionId); + }; + } + + /** + * Sets the client supplier that will supply the client if the region is allowed. + * + * @param baseClientSupplier the client supplier that will supply the client if the region is allowed. + */ + public AllowRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) { + this.baseClientSupplier = baseClientSupplier; + return this; + } + } + + /** + * A client supplier that supplies clients for any region except the specified AWS regions. + */ + public static class DenyRegionsAwsKmsClientSupplierBuilder { + + private final Set deniedRegions; + private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); + + private DenyRegionsAwsKmsClientSupplierBuilder(Set deniedRegions) { + notEmpty(deniedRegions, "At least one region is required"); + requireNonNull(baseClientSupplier, "baseClientSupplier is required"); + + this.deniedRegions = deniedRegions; + } + + /** + * Sets the client supplier that will supply the client if the region is allowed. + * + * @param baseClientSupplier the client supplier that will supply the client if the region is allowed. + */ + public DenyRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) { + this.baseClientSupplier = baseClientSupplier; + return this; + } + + public AwsKmsClientSupplier build() { + + return regionId -> { + + if (deniedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is in the set of denied regions %s", + regionId, deniedRegions)); + } + + return baseClientSupplier.getClient(regionId); + }; + } + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java index 078dff781..89f21479b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java +++ b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java @@ -21,6 +21,7 @@ import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; import com.amazonaws.util.IOUtils; import com.fasterxml.jackson.core.type.TypeReference; @@ -64,7 +65,7 @@ class TestVectorRunner { // We save the files in memory to avoid repeatedly retrieving them. // This won't work if the plaintexts are too large or numerous private static final Map cachedData = new HashMap<>(); - private static final AwsKmsClientSupplier awsKmsClientSupplier = AwsKmsClientSupplier.builder() + private static final AwsKmsClientSupplier awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder() .credentialsProvider(new DefaultAWSCredentialsProviderChain()) .build(); private static final KmsMasterKeyProvider kmsProv = KmsMasterKeyProvider diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java similarity index 80% rename from src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java rename to src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java index 748077d40..dec01d5a3 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java @@ -16,6 +16,7 @@ import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers.DefaultAwsKmsClientSupplierBuilder; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.kms.model.AWSKMSException; @@ -39,7 +40,7 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) -class AwsKmsClientSupplierTest { +class StandardAwsKmsClientSuppliersTest { @Mock AWSKMSClientBuilder kmsClientBuilder; @Mock AWSKMS awskms; @@ -58,7 +59,7 @@ void testCredentialsAndClientConfiguration() { when(kmsClientBuilder.withCredentials(credentialsProvider)).thenReturn(kmsClientBuilder); when(kmsClientBuilder.build()).thenReturn(awskms); - AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) .credentialsProvider(credentialsProvider) .clientConfiguration(clientConfiguration) .build(); @@ -70,64 +71,9 @@ void testCredentialsAndClientConfiguration() { verify(kmsClientBuilder).build(); } - @Test - void testAllowedAndExcludedRegions() { - AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); - - AwsKmsClientSupplier supplierWithAllowed = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .allowedRegions(Collections.singleton(REGION_1)) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithAllowed.getClient(REGION_1)); - assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); - - AwsKmsClientSupplier supplierWithExcluded = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .excludedRegions(Collections.singleton(REGION_1)) - .build(); - - when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertThrows(UnsupportedRegionException.class, () -> supplierWithExcluded.getClient(REGION_1)); - assertNotNull(supplierWithExcluded.getClient(REGION_2)); - - assertThrows(IllegalArgumentException.class, () -> new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .allowedRegions(Collections.singleton(REGION_1)) - .excludedRegions(Collections.singleton(REGION_2)) - .build()); - } - - @Test - void testClientCachingDisabled() { - AwsKmsClientSupplier supplierCachingDisabled = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .clientCaching(false) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - AWSKMS uncachedClient = supplierCachingDisabled.getClient(REGION_1); - verify(kmsClientBuilder, times(1)).build(); - - when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult()); - - uncachedClient.encrypt(encryptRequest); - supplierCachingDisabled.getClient(REGION_1); - verify(kmsClientBuilder, times(2)).build(); - } - @Test void testClientCaching() { - AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) .build(); when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); @@ -184,4 +130,46 @@ void testGetClientByKeyId() { assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(alias), s -> awskms)); assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(keyId), s -> awskms)); } + + @Test + void testAllowedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithAllowed = StandardAwsKmsClientSuppliers + .allowRegionsBuilder(Collections.singleton(REGION_1)) + .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithAllowed.getClient(REGION_1)); + assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); + } + + @Test + void testDeniedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithDenied = StandardAwsKmsClientSuppliers + .denyRegionsBuilder(Collections.singleton(REGION_1)) + .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build(); + + when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertThrows(UnsupportedRegionException.class, () -> supplierWithDenied.getClient(REGION_1)); + assertNotNull(supplierWithDenied.getClient(REGION_2)); + } }