diff --git a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java
index d7d7e071f..62ff53d45 100644
--- a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java
+++ b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java
@@ -13,7 +13,6 @@
package com.amazonaws.crypto.examples.datakeycaching;
-import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.AwsCryptoResult;
@@ -22,8 +21,8 @@
import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache;
import com.amazonaws.encryptionsdk.keyrings.Keyring;
import com.amazonaws.encryptionsdk.keyrings.StandardKeyrings;
-import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
+import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;
import com.amazonaws.regions.Region;
import com.amazonaws.services.kinesis.AmazonKinesis;
import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder;
@@ -32,7 +31,6 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -73,10 +71,10 @@ public MultiRegionRecordPusherExample(final Region[] regions, final String kmsAl
.build());
keyrings.add(StandardKeyrings.awsKmsBuilder()
- .awsKmsClientSupplier(AwsKmsClientSupplier.builder()
- .credentialsProvider(credentialsProvider)
- .allowedRegions(Collections.singleton(region.getName()))
- .build())
+ .awsKmsClientSupplier(StandardAwsKmsClientSuppliers
+ .allowRegionsBuilder(Collections.singleton(region.getName()))
+ .baseClientSupplier(StandardAwsKmsClientSuppliers.defaultBuilder()
+ .credentialsProvider(credentialsProvider).build()).build())
.generatorKeyId(AwsKmsCmkId.fromString(kmsAliasName)).build());
}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java
index 483a13506..5a67334e8 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java
@@ -16,6 +16,7 @@
import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;
import java.util.List;
@@ -111,7 +112,7 @@ public AwsKmsKeyringBuilder generatorKeyId(AwsKmsCmkId generatorKeyId) {
*/
public Keyring build() {
if (awsKmsClientSupplier == null) {
- awsKmsClientSupplier = AwsKmsClientSupplier.builder().build();
+ awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build();
}
return new AwsKmsKeyring(DataKeyEncryptionDao.awsKms(awsKmsClientSupplier, grantTokens),
diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
index 9cf37ea12..1a3fc16f6 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
@@ -13,8 +13,8 @@
package com.amazonaws.encryptionsdk.keyrings;
-import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
+import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;
import java.util.Arrays;
import java.util.List;
@@ -80,15 +80,14 @@ public static AwsKmsKeyringBuilder awsKmsBuilder() {
* AWS KMS Discovery keyrings do not specify any CMKs to decrypt with, and thus will attempt to decrypt
* using any encrypted data key in an encrypted message. AWS KMS Discovery keyrings do not perform encryption.
*
- * To create an AWS KMS Regional Discovery Keyring, construct an {@link AwsKmsClientSupplier} using
- * {@link AwsKmsClientSupplier#builder()} to specify which regions to include/exclude.
+ * To create an AWS KMS Regional Discovery Keyring, use {@link StandardAwsKmsClientSuppliers#allowRegionsBuilder} or
+ * {@link StandardAwsKmsClientSuppliers#denyRegionsBuilder} to specify which regions to include/exclude.
*
* For example, to include only CMKs in the us-east-1 region:
*
* StandardKeyrings.awsKmsDiscovery()
* .awsKmsClientSupplier(
- * AwsKmsClientSupplier.builder()
- * .allowedRegions(Collections.singleton("us-east-1")).build())
+ * StandardAwsKmsClientSuppliers.allowRegionsBuilder(Collections.singleton("us-east-1")).build()
* .build();
*
*
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java
index ceccb5936..92a0bfee1 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java
@@ -13,26 +13,13 @@
package com.amazonaws.encryptionsdk.kms;
-import com.amazonaws.ClientConfiguration;
import com.amazonaws.arn.Arn;
-import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
import com.amazonaws.services.kms.AWSKMS;
-import com.amazonaws.services.kms.AWSKMSClientBuilder;
-import com.amazonaws.services.kms.model.AWSKMSException;
import javax.annotation.Nullable;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Proxy;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
import static java.util.Objects.requireNonNull;
-import static org.apache.commons.lang3.Validate.isTrue;
-import static org.apache.commons.lang3.Validate.notEmpty;
/**
* Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The
@@ -51,15 +38,6 @@ public interface AwsKmsClientSupplier {
*/
AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException;
- /**
- * Gets a Builder for constructing an AwsKmsClientSupplier
- *
- * @return The builder
- */
- static Builder builder() {
- return new Builder(AWSKMSClientBuilder.standard());
- }
-
/**
* Parses region from the given key id (if possible) and passes that region to the
* given clientSupplier to produce an {@code AWSKMS} client.
@@ -78,156 +56,4 @@ static AWSKMS getClientByKeyId(AwsKmsCmkId keyId, AwsKmsClientSupplier clientSup
return clientSupplier.getClient(null);
}
-
- /**
- * Builder to construct an AwsKmsClientSupplier given various
- * optional settings.
- */
- class Builder {
-
- private AWSCredentialsProvider credentialsProvider;
- private ClientConfiguration clientConfiguration;
- private Set allowedRegions = Collections.emptySet();
- private Set excludedRegions = Collections.emptySet();
- private boolean clientCachingEnabled = true;
- private final Map clientsCache = new HashMap<>();
- private static final Set AWSKMS_METHODS = new HashSet<>();
- private AWSKMSClientBuilder awsKmsClientBuilder;
-
- static {
- AWSKMS_METHODS.add("generateDataKey");
- AWSKMS_METHODS.add("encrypt");
- AWSKMS_METHODS.add("decrypt");
- }
-
- Builder(AWSKMSClientBuilder awsKmsClientBuilder) {
- this.awsKmsClientBuilder = awsKmsClientBuilder;
- }
-
- public AwsKmsClientSupplier build() {
- isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(),
- "Either allowed regions or excluded regions may be set, not both.");
-
- return regionId -> {
- if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) {
- throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s",
- regionId, allowedRegions));
- }
-
- if (excludedRegions.contains(regionId)) {
- throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s",
- regionId, excludedRegions));
- }
-
- if (clientsCache.containsKey(regionId)) {
- return clientsCache.get(regionId);
- }
-
- if (credentialsProvider != null) {
- awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider);
- }
-
- if (clientConfiguration != null) {
- awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration);
- }
-
- if (regionId != null) {
- awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId);
- }
-
- AWSKMS client = awsKmsClientBuilder.build();
-
- if (clientCachingEnabled) {
- client = newCachingProxy(client, regionId);
- }
-
- return client;
- };
- }
-
- /**
- * Sets the AWSCredentialsProvider used by the client.
- *
- * @param credentialsProvider New AWSCredentialsProvider to use.
- */
- public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) {
- this.credentialsProvider = credentialsProvider;
- return this;
- }
-
- /**
- * Sets the ClientConfiguration to be used by the client.
- *
- * @param clientConfiguration Custom configuration to use.
- */
- public Builder clientConfiguration(ClientConfiguration clientConfiguration) {
- this.clientConfiguration = clientConfiguration;
- return this;
- }
-
- /**
- * Sets the AWS regions that the client supplier is permitted to use.
- *
- * @param regions The set of regions.
- */
- public Builder allowedRegions(Set regions) {
- notEmpty(regions, "At least one region is required");
- this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
- return this;
- }
-
- /**
- * Sets the AWS regions that the client supplier is not permitted to use.
- *
- * @param regions The set of regions.
- */
- public Builder excludedRegions(Set regions) {
- requireNonNull(regions, "regions is required");
- this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
- return this;
- }
-
- /**
- * When set to false, disables the AWSKMS client for each region from being cached and reused.
- * By default, client caching is enabled.
- *
- * @param enabled Whether or not caching is enabled.
- */
- public Builder clientCaching(boolean enabled) {
- this.clientCachingEnabled = enabled;
- return this;
- }
-
- /**
- * Creates a proxy for the AWSKMS client that will populate the client into the client cache
- * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a
- * a malicious user from causing a local resource DOS by sending ciphertext with a large number
- * of spurious regions, thereby filling the cache with regions and exhausting resources.
- *
- * @param client The client to proxy
- * @param regionId The region the client is associated with
- * @return The proxy
- */
- private AWSKMS newCachingProxy(AWSKMS client, String regionId) {
- return (AWSKMS) Proxy.newProxyInstance(
- AWSKMS.class.getClassLoader(),
- new Class[]{AWSKMS.class},
- (proxy, method, methodArgs) -> {
- try {
- final Object result = method.invoke(client, methodArgs);
- if (AWSKMS_METHODS.contains(method.getName())) {
- clientsCache.put(regionId, client);
- }
- return result;
- } catch (InvocationTargetException e) {
- if (e.getTargetException() instanceof AWSKMSException &&
- AWSKMS_METHODS.contains(method.getName())) {
- clientsCache.put(regionId, client);
- }
-
- throw e.getTargetException();
- }
- });
- }
- }
}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java
new file mode 100644
index 000000000..24b71c45c
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.ClientConfiguration;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.AWSKMSClientBuilder;
+import com.amazonaws.services.kms.model.AWSKMSException;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Proxy;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static java.util.Objects.requireNonNull;
+import static org.apache.commons.lang3.Validate.notEmpty;
+
+/**
+ * Factory methods for instantiating the standard {@code AwsKmsClientSupplier}s provided by the AWS Encryption SDK.
+ */
+public class StandardAwsKmsClientSuppliers {
+
+ /**
+ * A builder to construct the default AwsKmsClientSupplier that will create and cache clients
+ * for any region. Credentials and client configuration may be specified if necessary.
+ *
+ * @return The builder
+ */
+ public static DefaultAwsKmsClientSupplierBuilder defaultBuilder() {
+ return new DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder.standard());
+ }
+
+ /**
+ * A builder to construct an AwsKmsClientSupplier that will
+ * only supply clients for a given set of AWS regions.
+ *
+ * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for
+ * @return The builder
+ */
+ public static AllowRegionsAwsKmsClientSupplierBuilder allowRegionsBuilder(Set allowedRegions) {
+ return new AllowRegionsAwsKmsClientSupplierBuilder(allowedRegions);
+ }
+
+ /**
+ * A builder to construct an AwsKmsClientSupplier that will
+ * supply clients for all AWS regions except the given set of regions.
+ *
+ * @param deniedRegions the AWS regions that the client supplier will not supply clients for
+ * @return The builder
+ */
+ public static DenyRegionsAwsKmsClientSupplierBuilder denyRegionsBuilder(Set deniedRegions) {
+ return new DenyRegionsAwsKmsClientSupplierBuilder(deniedRegions);
+ }
+
+
+ /**
+ * Builder to construct an AwsKmsClientSupplier that will create and cache clients
+ * for any region. CredentialProvider and ClientConfiguration are optional and may
+ * be configured if necessary.
+ */
+ public static class DefaultAwsKmsClientSupplierBuilder {
+
+ private AWSCredentialsProvider credentialsProvider;
+ private ClientConfiguration clientConfiguration;
+ private final Map clientsCache = new ConcurrentHashMap<>();
+ private static final Set AWSKMS_METHODS = new HashSet<>();
+ private AWSKMSClientBuilder awsKmsClientBuilder;
+ private static final String NULL_REGION = "null-region";
+
+ static {
+ AWSKMS_METHODS.add("generateDataKey");
+ AWSKMS_METHODS.add("encrypt");
+ AWSKMS_METHODS.add("decrypt");
+ }
+
+ DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder awsKmsClientBuilder) {
+ this.awsKmsClientBuilder = awsKmsClientBuilder;
+ }
+
+ public AwsKmsClientSupplier build() {
+
+ return regionId -> {
+
+ if(regionId == null) {
+ regionId = NULL_REGION;
+ }
+
+ if (clientsCache.containsKey(regionId)) {
+ return clientsCache.get(regionId);
+ }
+
+ if (credentialsProvider != null) {
+ awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider);
+ }
+
+ if (clientConfiguration != null) {
+ awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration);
+ }
+
+ if (!regionId.equals(NULL_REGION)) {
+ awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId);
+ }
+
+ return newCachingProxy(awsKmsClientBuilder.build(), regionId);
+ };
+ }
+
+ /**
+ * Sets the AWSCredentialsProvider used by the client.
+ *
+ * @param credentialsProvider New AWSCredentialsProvider to use.
+ */
+ public DefaultAwsKmsClientSupplierBuilder credentialsProvider(AWSCredentialsProvider credentialsProvider) {
+ this.credentialsProvider = credentialsProvider;
+ return this;
+ }
+
+ /**
+ * Sets the ClientConfiguration to be used by the client.
+ *
+ * @param clientConfiguration Custom configuration to use.
+ */
+ public DefaultAwsKmsClientSupplierBuilder clientConfiguration(ClientConfiguration clientConfiguration) {
+ this.clientConfiguration = clientConfiguration;
+ return this;
+ }
+
+ /**
+ * Creates a proxy for the AWSKMS client that will populate the client into the client cache
+ * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a
+ * a malicious user from causing a local resource DOS by sending ciphertext with a large number
+ * of spurious regions, thereby filling the cache with regions and exhausting resources.
+ *
+ * @param client The client to proxy
+ * @param regionId The region the client is associated with
+ * @return The proxy
+ */
+ private AWSKMS newCachingProxy(AWSKMS client, String regionId) {
+ return (AWSKMS) Proxy.newProxyInstance(
+ AWSKMS.class.getClassLoader(),
+ new Class[]{AWSKMS.class},
+ (proxy, method, methodArgs) -> {
+ try {
+ final Object result = method.invoke(client, methodArgs);
+ if (AWSKMS_METHODS.contains(method.getName())) {
+ clientsCache.put(regionId, client);
+ }
+ return result;
+ } catch (InvocationTargetException e) {
+ if (e.getTargetException() instanceof AWSKMSException &&
+ AWSKMS_METHODS.contains(method.getName())) {
+ clientsCache.put(regionId, client);
+ }
+
+ throw e.getTargetException();
+ }
+ });
+ }
+ }
+
+ /**
+ * An AwsKmsClientSupplier that will only supply clients for a given set of AWS regions.
+ */
+ public static class AllowRegionsAwsKmsClientSupplierBuilder {
+
+ private final Set allowedRegions;
+ private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build();
+
+ private AllowRegionsAwsKmsClientSupplierBuilder(Set allowedRegions) {
+ notEmpty(allowedRegions, "At least one region is required");
+ requireNonNull(baseClientSupplier, "baseClientSupplier is required");
+
+ this.allowedRegions = allowedRegions;
+ }
+
+ /**
+ * Constructs the AwsKmsClientSupplier.
+ *
+ * @return The AwsKmsClientSupplier
+ */
+ public AwsKmsClientSupplier build() {
+ return regionId -> {
+
+ if (!allowedRegions.contains(regionId)) {
+ throw new UnsupportedRegionException(String.format("Region %s is not in the set of allowed regions %s",
+ regionId, allowedRegions));
+ }
+
+ return baseClientSupplier.getClient(regionId);
+ };
+ }
+
+ /**
+ * Sets the client supplier that will supply the client if the region is allowed.
+ *
+ * @param baseClientSupplier the client supplier that will supply the client if the region is allowed.
+ */
+ public AllowRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) {
+ this.baseClientSupplier = baseClientSupplier;
+ return this;
+ }
+ }
+
+ /**
+ * A client supplier that supplies clients for any region except the specified AWS regions.
+ */
+ public static class DenyRegionsAwsKmsClientSupplierBuilder {
+
+ private final Set deniedRegions;
+ private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build();
+
+ private DenyRegionsAwsKmsClientSupplierBuilder(Set deniedRegions) {
+ notEmpty(deniedRegions, "At least one region is required");
+ requireNonNull(baseClientSupplier, "baseClientSupplier is required");
+
+ this.deniedRegions = deniedRegions;
+ }
+
+ /**
+ * Sets the client supplier that will supply the client if the region is allowed.
+ *
+ * @param baseClientSupplier the client supplier that will supply the client if the region is allowed.
+ */
+ public DenyRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) {
+ this.baseClientSupplier = baseClientSupplier;
+ return this;
+ }
+
+ public AwsKmsClientSupplier build() {
+
+ return regionId -> {
+
+ if (deniedRegions.contains(regionId)) {
+ throw new UnsupportedRegionException(String.format("Region %s is in the set of denied regions %s",
+ regionId, deniedRegions));
+ }
+
+ return baseClientSupplier.getClient(regionId);
+ };
+ }
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java
index 078dff781..89f21479b 100644
--- a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java
+++ b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java
@@ -21,6 +21,7 @@
import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
+import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;
import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory;
import com.amazonaws.util.IOUtils;
import com.fasterxml.jackson.core.type.TypeReference;
@@ -64,7 +65,7 @@ class TestVectorRunner {
// We save the files in memory to avoid repeatedly retrieving them.
// This won't work if the plaintexts are too large or numerous
private static final Map cachedData = new HashMap<>();
- private static final AwsKmsClientSupplier awsKmsClientSupplier = AwsKmsClientSupplier.builder()
+ private static final AwsKmsClientSupplier awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder()
.credentialsProvider(new DefaultAWSCredentialsProviderChain())
.build();
private static final KmsMasterKeyProvider kmsProv = KmsMasterKeyProvider
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java
similarity index 80%
rename from src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java
rename to src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java
index 748077d40..dec01d5a3 100644
--- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java
@@ -16,6 +16,7 @@
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers.DefaultAwsKmsClientSupplierBuilder;
import com.amazonaws.services.kms.AWSKMS;
import com.amazonaws.services.kms.AWSKMSClientBuilder;
import com.amazonaws.services.kms.model.AWSKMSException;
@@ -39,7 +40,7 @@
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
-class AwsKmsClientSupplierTest {
+class StandardAwsKmsClientSuppliersTest {
@Mock AWSKMSClientBuilder kmsClientBuilder;
@Mock AWSKMS awskms;
@@ -58,7 +59,7 @@ void testCredentialsAndClientConfiguration() {
when(kmsClientBuilder.withCredentials(credentialsProvider)).thenReturn(kmsClientBuilder);
when(kmsClientBuilder.build()).thenReturn(awskms);
- AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
+ AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder)
.credentialsProvider(credentialsProvider)
.clientConfiguration(clientConfiguration)
.build();
@@ -70,64 +71,9 @@ void testCredentialsAndClientConfiguration() {
verify(kmsClientBuilder).build();
}
- @Test
- void testAllowedAndExcludedRegions() {
- AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
- .build();
-
- when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
- when(kmsClientBuilder.build()).thenReturn(awskms);
-
- assertNotNull(supplierWithDefaultValues.getClient(REGION_1));
-
- AwsKmsClientSupplier supplierWithAllowed = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
- .allowedRegions(Collections.singleton(REGION_1))
- .build();
-
- when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
- when(kmsClientBuilder.build()).thenReturn(awskms);
-
- assertNotNull(supplierWithAllowed.getClient(REGION_1));
- assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2));
-
- AwsKmsClientSupplier supplierWithExcluded = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
- .excludedRegions(Collections.singleton(REGION_1))
- .build();
-
- when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder);
- when(kmsClientBuilder.build()).thenReturn(awskms);
-
- assertThrows(UnsupportedRegionException.class, () -> supplierWithExcluded.getClient(REGION_1));
- assertNotNull(supplierWithExcluded.getClient(REGION_2));
-
- assertThrows(IllegalArgumentException.class, () -> new AwsKmsClientSupplier.Builder(kmsClientBuilder)
- .allowedRegions(Collections.singleton(REGION_1))
- .excludedRegions(Collections.singleton(REGION_2))
- .build());
- }
-
- @Test
- void testClientCachingDisabled() {
- AwsKmsClientSupplier supplierCachingDisabled = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
- .clientCaching(false)
- .build();
-
- when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
- when(kmsClientBuilder.build()).thenReturn(awskms);
-
- AWSKMS uncachedClient = supplierCachingDisabled.getClient(REGION_1);
- verify(kmsClientBuilder, times(1)).build();
-
- when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult());
-
- uncachedClient.encrypt(encryptRequest);
- supplierCachingDisabled.getClient(REGION_1);
- verify(kmsClientBuilder, times(2)).build();
- }
-
@Test
void testClientCaching() {
- AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder)
+ AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder)
.build();
when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
@@ -184,4 +130,46 @@ void testGetClientByKeyId() {
assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(alias), s -> awskms));
assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(keyId), s -> awskms));
}
+
+ @Test
+ void testAllowedRegions() {
+ AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder)
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertNotNull(supplierWithDefaultValues.getClient(REGION_1));
+
+ AwsKmsClientSupplier supplierWithAllowed = StandardAwsKmsClientSuppliers
+ .allowRegionsBuilder(Collections.singleton(REGION_1))
+ .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertNotNull(supplierWithAllowed.getClient(REGION_1));
+ assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2));
+ }
+
+ @Test
+ void testDeniedRegions() {
+ AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder)
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertNotNull(supplierWithDefaultValues.getClient(REGION_1));
+
+ AwsKmsClientSupplier supplierWithDenied = StandardAwsKmsClientSuppliers
+ .denyRegionsBuilder(Collections.singleton(REGION_1))
+ .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build();
+
+ when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertThrows(UnsupportedRegionException.class, () -> supplierWithDenied.getClient(REGION_1));
+ assertNotNull(supplierWithDenied.getClient(REGION_2));
+ }
}