From dd4ffda8821f2e0fe88492574a4cd2654e6c8b4a Mon Sep 17 00:00:00 2001 From: Wesley Rosenblum Date: Wed, 11 Mar 2020 12:53:54 -0700 Subject: [PATCH 1/3] Make client suppliers composable --- .../MultiRegionRecordPusherExample.java | 9 ++- .../keyrings/StandardKeyrings.java | 10 +-- .../kms/AllowRegionsAwsKmsClientSupplier.java | 67 +++++++++++++++++++ .../kms/AwsKmsClientSupplier.java | 63 ++--------------- .../kms/DenyRegionsAwsKmsClientSupplier.java | 67 +++++++++++++++++++ .../AllowRegionsAwsKmsClientSupplierTest.java | 58 ++++++++++++++++ .../kms/AwsKmsClientSupplierTest.java | 59 ---------------- .../DenyRegionsAwsKmsClientSupplierTest.java | 58 ++++++++++++++++ 8 files changed, 263 insertions(+), 128 deletions(-) create mode 100644 src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java create mode 100644 src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java create mode 100644 src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java create mode 100644 src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java diff --git a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java index d7d7e071f..468756969 100644 --- a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java +++ b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java @@ -13,7 +13,6 @@ package com.amazonaws.crypto.examples.datakeycaching; -import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.AwsCryptoResult; @@ -22,6 +21,7 @@ import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; import com.amazonaws.encryptionsdk.keyrings.Keyring; import com.amazonaws.encryptionsdk.keyrings.StandardKeyrings; +import com.amazonaws.encryptionsdk.kms.AllowRegionsAwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; import com.amazonaws.regions.Region; @@ -32,7 +32,6 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -73,10 +72,10 @@ public MultiRegionRecordPusherExample(final Region[] regions, final String kmsAl .build()); keyrings.add(StandardKeyrings.awsKmsBuilder() - .awsKmsClientSupplier(AwsKmsClientSupplier.builder() + .awsKmsClientSupplier(new AllowRegionsAwsKmsClientSupplier(Collections.singleton(region.getName()), + AwsKmsClientSupplier.builder() .credentialsProvider(credentialsProvider) - .allowedRegions(Collections.singleton(region.getName())) - .build()) + .build())) .generatorKeyId(AwsKmsCmkId.fromString(kmsAliasName)).build()); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java index 9cf37ea12..63379d62d 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java @@ -13,8 +13,9 @@ package com.amazonaws.encryptionsdk.keyrings; -import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; +import com.amazonaws.encryptionsdk.kms.AllowRegionsAwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; +import com.amazonaws.encryptionsdk.kms.DenyRegionsAwsKmsClientSupplier; import java.util.Arrays; import java.util.List; @@ -80,15 +81,14 @@ public static AwsKmsKeyringBuilder awsKmsBuilder() { * AWS KMS Discovery keyrings do not specify any CMKs to decrypt with, and thus will attempt to decrypt * using any encrypted data key in an encrypted message. AWS KMS Discovery keyrings do not perform encryption. *

- * To create an AWS KMS Regional Discovery Keyring, construct an {@link AwsKmsClientSupplier} using - * {@link AwsKmsClientSupplier#builder()} to specify which regions to include/exclude. + * To create an AWS KMS Regional Discovery Keyring, use an {@link AllowRegionsAwsKmsClientSupplier} or a + * {@link DenyRegionsAwsKmsClientSupplier} to specify which regions to include/exclude. *

* For example, to include only CMKs in the us-east-1 region: *
      * StandardKeyrings.awsKmsDiscovery()
      *             .awsKmsClientSupplier(
-     *                     AwsKmsClientSupplier.builder()
-     *                     .allowedRegions(Collections.singleton("us-east-1")).build())
+     *                     new AllowRegionsAwsKmsClientSupplier(Collections.singleton("us-east-1")))
      *             .build();
      * 
* diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java new file mode 100644 index 000000000..3ba4bca84 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java @@ -0,0 +1,67 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.notEmpty; + +/** + * An AwsKmsClientSupplier that will only supply clients for a given set of AWS regions. + */ +public class AllowRegionsAwsKmsClientSupplier implements AwsKmsClientSupplier { + + private final Set allowedRegions; + private final AwsKmsClientSupplier baseSupplier; + + /** + * Constructs a client supplier that only supplies clients for the specified AWS regions. + * + * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for + */ + public AllowRegionsAwsKmsClientSupplier(Set allowedRegions) { + this(allowedRegions, AwsKmsClientSupplier.builder().build()); + } + + /** + * Constructs a client supplier that only supplies clients for the specified AWS regions. + * Client supplying is delegated to the given baseSupplier. + * + * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for + * @param baseSupplier the client supplier that will supply the client if the region is allowed + */ + public AllowRegionsAwsKmsClientSupplier(Set allowedRegions, AwsKmsClientSupplier baseSupplier) { + notEmpty(allowedRegions, "At least one region is required"); + requireNonNull(baseSupplier, "baseSupplier is required"); + this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(allowedRegions)); + this.baseSupplier = baseSupplier; + } + + public AWSKMS getClient(@Nullable String regionId) { + + if (!allowedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is not in the set of allowed regions %s", + regionId, allowedRegions)); + } + + return baseSupplier.getClient(regionId); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java index ceccb5936..fbd792354 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java @@ -24,15 +24,12 @@ import javax.annotation.Nullable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Proxy; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import static java.util.Objects.requireNonNull; -import static org.apache.commons.lang3.Validate.isTrue; -import static org.apache.commons.lang3.Validate.notEmpty; /** * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The @@ -80,16 +77,14 @@ static AWSKMS getClientByKeyId(AwsKmsCmkId keyId, AwsKmsClientSupplier clientSup } /** - * Builder to construct an AwsKmsClientSupplier given various - * optional settings. + * Builder to construct an AwsKmsClientSupplier that will create and cache clients + * for any region. CredentialProvider and ClientConfiguration are optional and may + * be configured if necessary. */ class Builder { private AWSCredentialsProvider credentialsProvider; private ClientConfiguration clientConfiguration; - private Set allowedRegions = Collections.emptySet(); - private Set excludedRegions = Collections.emptySet(); - private boolean clientCachingEnabled = true; private final Map clientsCache = new HashMap<>(); private static final Set AWSKMS_METHODS = new HashSet<>(); private AWSKMSClientBuilder awsKmsClientBuilder; @@ -105,19 +100,8 @@ class Builder { } public AwsKmsClientSupplier build() { - isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(), - "Either allowed regions or excluded regions may be set, not both."); return regionId -> { - if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s", - regionId, allowedRegions)); - } - - if (excludedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s", - regionId, excludedRegions)); - } if (clientsCache.containsKey(regionId)) { return clientsCache.get(regionId); @@ -135,13 +119,7 @@ public AwsKmsClientSupplier build() { awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); } - AWSKMS client = awsKmsClientBuilder.build(); - - if (clientCachingEnabled) { - client = newCachingProxy(client, regionId); - } - - return client; + return newCachingProxy(awsKmsClientBuilder.build(), regionId); }; } @@ -165,39 +143,6 @@ public Builder clientConfiguration(ClientConfiguration clientConfiguration) { return this; } - /** - * Sets the AWS regions that the client supplier is permitted to use. - * - * @param regions The set of regions. - */ - public Builder allowedRegions(Set regions) { - notEmpty(regions, "At least one region is required"); - this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); - return this; - } - - /** - * Sets the AWS regions that the client supplier is not permitted to use. - * - * @param regions The set of regions. - */ - public Builder excludedRegions(Set regions) { - requireNonNull(regions, "regions is required"); - this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); - return this; - } - - /** - * When set to false, disables the AWSKMS client for each region from being cached and reused. - * By default, client caching is enabled. - * - * @param enabled Whether or not caching is enabled. - */ - public Builder clientCaching(boolean enabled) { - this.clientCachingEnabled = enabled; - return this; - } - /** * Creates a proxy for the AWSKMS client that will populate the client into the client cache * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java new file mode 100644 index 000000000..4df614fe1 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java @@ -0,0 +1,67 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.notEmpty; + +/** + * A client supplier that supplies clients for any region except the specified AWS regions. + */ +public class DenyRegionsAwsKmsClientSupplier implements AwsKmsClientSupplier { + + private final Set deniedRegions; + private final AwsKmsClientSupplier baseSupplier; + + /** + * Constructs a client supplier that supplies clients for any region except the specified AWS regions. + * + * @param deniedRegions the AWS regions that the client supplier will not supply clients for + */ + public DenyRegionsAwsKmsClientSupplier(Set deniedRegions) { + this(deniedRegions, AwsKmsClientSupplier.builder().build()); + } + + /** + * Constructs a client supplier that supplies clients for any region except the specified AWS regions. + * Client supplying is delegated to the given baseSupplier. + * + * @param deniedRegions the AWS regions that the client supplier will not supply clients for + * @param baseSupplier the client supplier that will supply the client if the region is not denied + */ + public DenyRegionsAwsKmsClientSupplier(Set deniedRegions, AwsKmsClientSupplier baseSupplier) { + notEmpty(deniedRegions, "At least one region is required"); + requireNonNull(baseSupplier, "baseSupplier is required"); + this.deniedRegions = Collections.unmodifiableSet(new HashSet<>(deniedRegions)); + this.baseSupplier = baseSupplier; + } + + public AWSKMS getClient(@Nullable String regionId) { + + if (deniedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is in the set of denied regions %s", + regionId, deniedRegions)); + } + + return baseSupplier.getClient(regionId); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java new file mode 100644 index 000000000..62bb0ff57 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class AllowRegionsAwsKmsClientSupplierTest { + + @Mock AWSKMSClientBuilder kmsClientBuilder; + @Mock AWSKMS awskms; + private static final String REGION_1 = "us-east-1"; + private static final String REGION_2 = "us-west-2"; + + @Test + void testAllowedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithAllowed = new AllowRegionsAwsKmsClientSupplier( + Collections.singleton(REGION_1), + new AwsKmsClientSupplier.Builder(kmsClientBuilder).build()); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithAllowed.getClient(REGION_1)); + assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java index 748077d40..d65ef1f4b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java @@ -15,7 +15,6 @@ import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.kms.model.AWSKMSException; @@ -29,10 +28,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import java.util.Collections; - import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -70,61 +66,6 @@ void testCredentialsAndClientConfiguration() { verify(kmsClientBuilder).build(); } - @Test - void testAllowedAndExcludedRegions() { - AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); - - AwsKmsClientSupplier supplierWithAllowed = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .allowedRegions(Collections.singleton(REGION_1)) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithAllowed.getClient(REGION_1)); - assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); - - AwsKmsClientSupplier supplierWithExcluded = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .excludedRegions(Collections.singleton(REGION_1)) - .build(); - - when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertThrows(UnsupportedRegionException.class, () -> supplierWithExcluded.getClient(REGION_1)); - assertNotNull(supplierWithExcluded.getClient(REGION_2)); - - assertThrows(IllegalArgumentException.class, () -> new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .allowedRegions(Collections.singleton(REGION_1)) - .excludedRegions(Collections.singleton(REGION_2)) - .build()); - } - - @Test - void testClientCachingDisabled() { - AwsKmsClientSupplier supplierCachingDisabled = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .clientCaching(false) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - AWSKMS uncachedClient = supplierCachingDisabled.getClient(REGION_1); - verify(kmsClientBuilder, times(1)).build(); - - when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult()); - - uncachedClient.encrypt(encryptRequest); - supplierCachingDisabled.getClient(REGION_1); - verify(kmsClientBuilder, times(2)).build(); - } - @Test void testClientCaching() { AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder) diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java new file mode 100644 index 000000000..bcd3860de --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class DenyRegionsAwsKmsClientSupplierTest { + + @Mock AWSKMSClientBuilder kmsClientBuilder; + @Mock AWSKMS awskms; + private static final String REGION_1 = "us-east-1"; + private static final String REGION_2 = "us-west-2"; + + @Test + void testDeniedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithDenied = new DenyRegionsAwsKmsClientSupplier( + Collections.singleton(REGION_1), + new AwsKmsClientSupplier.Builder(kmsClientBuilder).build()); + + when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertThrows(UnsupportedRegionException.class, () -> supplierWithDenied.getClient(REGION_1)); + assertNotNull(supplierWithDenied.getClient(REGION_2)); + } +} From 010840c4348e6069e393b42cb6ef29b5cbf5a748 Mon Sep 17 00:00:00 2001 From: Wesley Rosenblum Date: Wed, 11 Mar 2020 17:55:49 -0700 Subject: [PATCH 2/3] Refactor to move suppliers to StandardAwsKmsClientSuppliers class --- .../MultiRegionRecordPusherExample.java | 11 +- .../keyrings/AwsKmsKeyringBuilder.java | 3 +- .../keyrings/StandardKeyrings.java | 9 +- .../kms/AllowRegionsAwsKmsClientSupplier.java | 67 ----- .../kms/AwsKmsClientSupplier.java | 119 --------- .../kms/DenyRegionsAwsKmsClientSupplier.java | 67 ----- .../kms/StandardAwsKmsClientSuppliers.java | 252 ++++++++++++++++++ .../encryptionsdk/TestVectorRunner.java | 3 +- .../AllowRegionsAwsKmsClientSupplierTest.java | 58 ---- .../DenyRegionsAwsKmsClientSupplierTest.java | 58 ---- ...=> StandardAwsKmsClientSuppliersTest.java} | 53 +++- 11 files changed, 315 insertions(+), 385 deletions(-) delete mode 100644 src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java delete mode 100644 src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java create mode 100644 src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java delete mode 100644 src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java delete mode 100644 src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java rename src/test/java/com/amazonaws/encryptionsdk/kms/{AwsKmsClientSupplierTest.java => StandardAwsKmsClientSuppliersTest.java} (69%) diff --git a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java index 468756969..62ff53d45 100644 --- a/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java +++ b/src/examples/java/com/amazonaws/crypto/examples/datakeycaching/MultiRegionRecordPusherExample.java @@ -21,9 +21,8 @@ import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; import com.amazonaws.encryptionsdk.keyrings.Keyring; import com.amazonaws.encryptionsdk.keyrings.StandardKeyrings; -import com.amazonaws.encryptionsdk.kms.AllowRegionsAwsKmsClientSupplier; -import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import com.amazonaws.regions.Region; import com.amazonaws.services.kinesis.AmazonKinesis; import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; @@ -72,10 +71,10 @@ public MultiRegionRecordPusherExample(final Region[] regions, final String kmsAl .build()); keyrings.add(StandardKeyrings.awsKmsBuilder() - .awsKmsClientSupplier(new AllowRegionsAwsKmsClientSupplier(Collections.singleton(region.getName()), - AwsKmsClientSupplier.builder() - .credentialsProvider(credentialsProvider) - .build())) + .awsKmsClientSupplier(StandardAwsKmsClientSuppliers + .allowRegionsBuilder(Collections.singleton(region.getName())) + .baseClientSupplier(StandardAwsKmsClientSuppliers.defaultBuilder() + .credentialsProvider(credentialsProvider).build()).build()) .generatorKeyId(AwsKmsCmkId.fromString(kmsAliasName)).build()); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java index 483a13506..5a67334e8 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/AwsKmsKeyringBuilder.java @@ -16,6 +16,7 @@ import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import java.util.List; @@ -111,7 +112,7 @@ public AwsKmsKeyringBuilder generatorKeyId(AwsKmsCmkId generatorKeyId) { */ public Keyring build() { if (awsKmsClientSupplier == null) { - awsKmsClientSupplier = AwsKmsClientSupplier.builder().build(); + awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); } return new AwsKmsKeyring(DataKeyEncryptionDao.awsKms(awsKmsClientSupplier, grantTokens), diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java index 63379d62d..1a3fc16f6 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java @@ -13,9 +13,8 @@ package com.amazonaws.encryptionsdk.keyrings; -import com.amazonaws.encryptionsdk.kms.AllowRegionsAwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; -import com.amazonaws.encryptionsdk.kms.DenyRegionsAwsKmsClientSupplier; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import java.util.Arrays; import java.util.List; @@ -81,14 +80,14 @@ public static AwsKmsKeyringBuilder awsKmsBuilder() { * AWS KMS Discovery keyrings do not specify any CMKs to decrypt with, and thus will attempt to decrypt * using any encrypted data key in an encrypted message. AWS KMS Discovery keyrings do not perform encryption. *

- * To create an AWS KMS Regional Discovery Keyring, use an {@link AllowRegionsAwsKmsClientSupplier} or a - * {@link DenyRegionsAwsKmsClientSupplier} to specify which regions to include/exclude. + * To create an AWS KMS Regional Discovery Keyring, use {@link StandardAwsKmsClientSuppliers#allowRegionsBuilder} or + * {@link StandardAwsKmsClientSuppliers#denyRegionsBuilder} to specify which regions to include/exclude. *

* For example, to include only CMKs in the us-east-1 region: *
      * StandardKeyrings.awsKmsDiscovery()
      *             .awsKmsClientSupplier(
-     *                     new AllowRegionsAwsKmsClientSupplier(Collections.singleton("us-east-1")))
+     *                     StandardAwsKmsClientSuppliers.allowRegionsBuilder(Collections.singleton("us-east-1")).build()
      *             .build();
      * 
* diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java deleted file mode 100644 index 3ba4bca84..000000000 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplier.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except - * in compliance with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package com.amazonaws.encryptionsdk.kms; - -import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; -import com.amazonaws.services.kms.AWSKMS; - -import javax.annotation.Nullable; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - -import static java.util.Objects.requireNonNull; -import static org.apache.commons.lang3.Validate.notEmpty; - -/** - * An AwsKmsClientSupplier that will only supply clients for a given set of AWS regions. - */ -public class AllowRegionsAwsKmsClientSupplier implements AwsKmsClientSupplier { - - private final Set allowedRegions; - private final AwsKmsClientSupplier baseSupplier; - - /** - * Constructs a client supplier that only supplies clients for the specified AWS regions. - * - * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for - */ - public AllowRegionsAwsKmsClientSupplier(Set allowedRegions) { - this(allowedRegions, AwsKmsClientSupplier.builder().build()); - } - - /** - * Constructs a client supplier that only supplies clients for the specified AWS regions. - * Client supplying is delegated to the given baseSupplier. - * - * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for - * @param baseSupplier the client supplier that will supply the client if the region is allowed - */ - public AllowRegionsAwsKmsClientSupplier(Set allowedRegions, AwsKmsClientSupplier baseSupplier) { - notEmpty(allowedRegions, "At least one region is required"); - requireNonNull(baseSupplier, "baseSupplier is required"); - this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(allowedRegions)); - this.baseSupplier = baseSupplier; - } - - public AWSKMS getClient(@Nullable String regionId) { - - if (!allowedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is not in the set of allowed regions %s", - regionId, allowedRegions)); - } - - return baseSupplier.getClient(regionId); - } -} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java index fbd792354..92a0bfee1 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplier.java @@ -13,21 +13,11 @@ package com.amazonaws.encryptionsdk.kms; -import com.amazonaws.ClientConfiguration; import com.amazonaws.arn.Arn; -import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.AWSKMSClientBuilder; -import com.amazonaws.services.kms.model.AWSKMSException; import javax.annotation.Nullable; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Proxy; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; import static java.util.Objects.requireNonNull; @@ -48,15 +38,6 @@ public interface AwsKmsClientSupplier { */ AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException; - /** - * Gets a Builder for constructing an AwsKmsClientSupplier - * - * @return The builder - */ - static Builder builder() { - return new Builder(AWSKMSClientBuilder.standard()); - } - /** * Parses region from the given key id (if possible) and passes that region to the * given clientSupplier to produce an {@code AWSKMS} client. @@ -75,104 +56,4 @@ static AWSKMS getClientByKeyId(AwsKmsCmkId keyId, AwsKmsClientSupplier clientSup return clientSupplier.getClient(null); } - - /** - * Builder to construct an AwsKmsClientSupplier that will create and cache clients - * for any region. CredentialProvider and ClientConfiguration are optional and may - * be configured if necessary. - */ - class Builder { - - private AWSCredentialsProvider credentialsProvider; - private ClientConfiguration clientConfiguration; - private final Map clientsCache = new HashMap<>(); - private static final Set AWSKMS_METHODS = new HashSet<>(); - private AWSKMSClientBuilder awsKmsClientBuilder; - - static { - AWSKMS_METHODS.add("generateDataKey"); - AWSKMS_METHODS.add("encrypt"); - AWSKMS_METHODS.add("decrypt"); - } - - Builder(AWSKMSClientBuilder awsKmsClientBuilder) { - this.awsKmsClientBuilder = awsKmsClientBuilder; - } - - public AwsKmsClientSupplier build() { - - return regionId -> { - - if (clientsCache.containsKey(regionId)) { - return clientsCache.get(regionId); - } - - if (credentialsProvider != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider); - } - - if (clientConfiguration != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration); - } - - if (regionId != null) { - awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); - } - - return newCachingProxy(awsKmsClientBuilder.build(), regionId); - }; - } - - /** - * Sets the AWSCredentialsProvider used by the client. - * - * @param credentialsProvider New AWSCredentialsProvider to use. - */ - public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) { - this.credentialsProvider = credentialsProvider; - return this; - } - - /** - * Sets the ClientConfiguration to be used by the client. - * - * @param clientConfiguration Custom configuration to use. - */ - public Builder clientConfiguration(ClientConfiguration clientConfiguration) { - this.clientConfiguration = clientConfiguration; - return this; - } - - /** - * Creates a proxy for the AWSKMS client that will populate the client into the client cache - * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a - * a malicious user from causing a local resource DOS by sending ciphertext with a large number - * of spurious regions, thereby filling the cache with regions and exhausting resources. - * - * @param client The client to proxy - * @param regionId The region the client is associated with - * @return The proxy - */ - private AWSKMS newCachingProxy(AWSKMS client, String regionId) { - return (AWSKMS) Proxy.newProxyInstance( - AWSKMS.class.getClassLoader(), - new Class[]{AWSKMS.class}, - (proxy, method, methodArgs) -> { - try { - final Object result = method.invoke(client, methodArgs); - if (AWSKMS_METHODS.contains(method.getName())) { - clientsCache.put(regionId, client); - } - return result; - } catch (InvocationTargetException e) { - if (e.getTargetException() instanceof AWSKMSException && - AWSKMS_METHODS.contains(method.getName())) { - clientsCache.put(regionId, client); - } - - throw e.getTargetException(); - } - }); - } - } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java deleted file mode 100644 index 4df614fe1..000000000 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplier.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except - * in compliance with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package com.amazonaws.encryptionsdk.kms; - -import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; -import com.amazonaws.services.kms.AWSKMS; - -import javax.annotation.Nullable; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - -import static java.util.Objects.requireNonNull; -import static org.apache.commons.lang3.Validate.notEmpty; - -/** - * A client supplier that supplies clients for any region except the specified AWS regions. - */ -public class DenyRegionsAwsKmsClientSupplier implements AwsKmsClientSupplier { - - private final Set deniedRegions; - private final AwsKmsClientSupplier baseSupplier; - - /** - * Constructs a client supplier that supplies clients for any region except the specified AWS regions. - * - * @param deniedRegions the AWS regions that the client supplier will not supply clients for - */ - public DenyRegionsAwsKmsClientSupplier(Set deniedRegions) { - this(deniedRegions, AwsKmsClientSupplier.builder().build()); - } - - /** - * Constructs a client supplier that supplies clients for any region except the specified AWS regions. - * Client supplying is delegated to the given baseSupplier. - * - * @param deniedRegions the AWS regions that the client supplier will not supply clients for - * @param baseSupplier the client supplier that will supply the client if the region is not denied - */ - public DenyRegionsAwsKmsClientSupplier(Set deniedRegions, AwsKmsClientSupplier baseSupplier) { - notEmpty(deniedRegions, "At least one region is required"); - requireNonNull(baseSupplier, "baseSupplier is required"); - this.deniedRegions = Collections.unmodifiableSet(new HashSet<>(deniedRegions)); - this.baseSupplier = baseSupplier; - } - - public AWSKMS getClient(@Nullable String regionId) { - - if (deniedRegions.contains(regionId)) { - throw new UnsupportedRegionException(String.format("Region %s is in the set of denied regions %s", - regionId, deniedRegions)); - } - - return baseSupplier.getClient(regionId); - } -} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java new file mode 100644 index 000000000..f92de0101 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java @@ -0,0 +1,252 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import com.amazonaws.services.kms.model.AWSKMSException; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Proxy; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.notEmpty; + +/** + * Factory methods for instantiating the standard {@code AwsKmsClientSupplier}s provided by the AWS Encryption SDK. + */ +public class StandardAwsKmsClientSuppliers { + + /** + * A builder to construct the default AwsKmsClientSupplier that will create and cache clients + * for any region. Credentials and client configuration may be specified if necessary. + * + * @return The builder + */ + public static DefaultAwsKmsClientSupplierBuilder defaultBuilder() { + return new DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder.standard()); + } + + /** + * A builder to construct an AwsKmsClientSupplier that will + * only supply clients for a given set of AWS regions. + * + * @param allowedRegions the AWS regions that the client supplier is allowed to supply clients for + * @return The builder + */ + public static AllowRegionsAwsKmsClientSupplierBuilder allowRegionsBuilder(Set allowedRegions) { + return new AllowRegionsAwsKmsClientSupplierBuilder(allowedRegions); + } + + /** + * A builder to construct an AwsKmsClientSupplier that will + * supply clients for all AWS regions except the given set of regions. + * + * @param deniedRegions the AWS regions that the client supplier will not supply clients for + * @return The builder + */ + public static DenyRegionsAwsKmsClientSupplierBuilder denyRegionsBuilder(Set deniedRegions) { + return new DenyRegionsAwsKmsClientSupplierBuilder(deniedRegions); + } + + + /** + * Builder to construct an AwsKmsClientSupplier that will create and cache clients + * for any region. CredentialProvider and ClientConfiguration are optional and may + * be configured if necessary. + */ + public static class DefaultAwsKmsClientSupplierBuilder { + + private AWSCredentialsProvider credentialsProvider; + private ClientConfiguration clientConfiguration; + private final Map clientsCache = new HashMap<>(); + private static final Set AWSKMS_METHODS = new HashSet<>(); + private AWSKMSClientBuilder awsKmsClientBuilder; + + static { + AWSKMS_METHODS.add("generateDataKey"); + AWSKMS_METHODS.add("encrypt"); + AWSKMS_METHODS.add("decrypt"); + } + + DefaultAwsKmsClientSupplierBuilder(AWSKMSClientBuilder awsKmsClientBuilder) { + this.awsKmsClientBuilder = awsKmsClientBuilder; + } + + public AwsKmsClientSupplier build() { + + return regionId -> { + + if (clientsCache.containsKey(regionId)) { + return clientsCache.get(regionId); + } + + if (credentialsProvider != null) { + awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider); + } + + if (clientConfiguration != null) { + awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration); + } + + if (regionId != null) { + awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); + } + + return newCachingProxy(awsKmsClientBuilder.build(), regionId); + }; + } + + /** + * Sets the AWSCredentialsProvider used by the client. + * + * @param credentialsProvider New AWSCredentialsProvider to use. + */ + public DefaultAwsKmsClientSupplierBuilder credentialsProvider(AWSCredentialsProvider credentialsProvider) { + this.credentialsProvider = credentialsProvider; + return this; + } + + /** + * Sets the ClientConfiguration to be used by the client. + * + * @param clientConfiguration Custom configuration to use. + */ + public DefaultAwsKmsClientSupplierBuilder clientConfiguration(ClientConfiguration clientConfiguration) { + this.clientConfiguration = clientConfiguration; + return this; + } + + /** + * Creates a proxy for the AWSKMS client that will populate the client into the client cache + * after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a + * a malicious user from causing a local resource DOS by sending ciphertext with a large number + * of spurious regions, thereby filling the cache with regions and exhausting resources. + * + * @param client The client to proxy + * @param regionId The region the client is associated with + * @return The proxy + */ + private AWSKMS newCachingProxy(AWSKMS client, String regionId) { + return (AWSKMS) Proxy.newProxyInstance( + AWSKMS.class.getClassLoader(), + new Class[]{AWSKMS.class}, + (proxy, method, methodArgs) -> { + try { + final Object result = method.invoke(client, methodArgs); + if (AWSKMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + return result; + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof AWSKMSException && + AWSKMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + + throw e.getTargetException(); + } + }); + } + } + + /** + * An AwsKmsClientSupplier that will only supply clients for a given set of AWS regions. + */ + public static class AllowRegionsAwsKmsClientSupplierBuilder { + + private final Set allowedRegions; + private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); + + private AllowRegionsAwsKmsClientSupplierBuilder(Set allowedRegions) { + notEmpty(allowedRegions, "At least one region is required"); + requireNonNull(baseClientSupplier, "baseClientSupplier is required"); + + this.allowedRegions = allowedRegions; + } + + /** + * Constructs the AwsKmsClientSupplier. + * + * @return The AwsKmsClientSupplier + */ + public AwsKmsClientSupplier build() { + return regionId -> { + + if (!allowedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is not in the set of allowed regions %s", + regionId, allowedRegions)); + } + + return baseClientSupplier.getClient(regionId); + }; + } + + /** + * Sets the client supplier that will supply the client if the region is allowed. + * + * @param baseClientSupplier the client supplier that will supply the client if the region is allowed. + */ + public AllowRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) { + this.baseClientSupplier = baseClientSupplier; + return this; + } + } + + /** + * A client supplier that supplies clients for any region except the specified AWS regions. + */ + public static class DenyRegionsAwsKmsClientSupplierBuilder { + + private final Set deniedRegions; + private AwsKmsClientSupplier baseClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build(); + + private DenyRegionsAwsKmsClientSupplierBuilder(Set deniedRegions) { + notEmpty(deniedRegions, "At least one region is required"); + requireNonNull(baseClientSupplier, "baseClientSupplier is required"); + + this.deniedRegions = deniedRegions; + } + + /** + * Sets the client supplier that will supply the client if the region is allowed. + * + * @param baseClientSupplier the client supplier that will supply the client if the region is allowed. + */ + public DenyRegionsAwsKmsClientSupplierBuilder baseClientSupplier(AwsKmsClientSupplier baseClientSupplier) { + this.baseClientSupplier = baseClientSupplier; + return this; + } + + public AwsKmsClientSupplier build() { + + return regionId -> { + + if (deniedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is in the set of denied regions %s", + regionId, deniedRegions)); + } + + return baseClientSupplier.getClient(regionId); + }; + } + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java index 078dff781..89f21479b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java +++ b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java @@ -21,6 +21,7 @@ import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier; import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers; import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; import com.amazonaws.util.IOUtils; import com.fasterxml.jackson.core.type.TypeReference; @@ -64,7 +65,7 @@ class TestVectorRunner { // We save the files in memory to avoid repeatedly retrieving them. // This won't work if the plaintexts are too large or numerous private static final Map cachedData = new HashMap<>(); - private static final AwsKmsClientSupplier awsKmsClientSupplier = AwsKmsClientSupplier.builder() + private static final AwsKmsClientSupplier awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder() .credentialsProvider(new DefaultAWSCredentialsProviderChain()) .build(); private static final KmsMasterKeyProvider kmsProv = KmsMasterKeyProvider diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java deleted file mode 100644 index 62bb0ff57..000000000 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AllowRegionsAwsKmsClientSupplierTest.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except - * in compliance with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package com.amazonaws.encryptionsdk.kms; - -import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; -import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.AWSKMSClientBuilder; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import java.util.Collections; - -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class AllowRegionsAwsKmsClientSupplierTest { - - @Mock AWSKMSClientBuilder kmsClientBuilder; - @Mock AWSKMS awskms; - private static final String REGION_1 = "us-east-1"; - private static final String REGION_2 = "us-west-2"; - - @Test - void testAllowedRegions() { - AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); - - AwsKmsClientSupplier supplierWithAllowed = new AllowRegionsAwsKmsClientSupplier( - Collections.singleton(REGION_1), - new AwsKmsClientSupplier.Builder(kmsClientBuilder).build()); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithAllowed.getClient(REGION_1)); - assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); - } -} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java deleted file mode 100644 index bcd3860de..000000000 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/DenyRegionsAwsKmsClientSupplierTest.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except - * in compliance with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package com.amazonaws.encryptionsdk.kms; - -import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; -import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.AWSKMSClientBuilder; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import java.util.Collections; - -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class DenyRegionsAwsKmsClientSupplierTest { - - @Mock AWSKMSClientBuilder kmsClientBuilder; - @Mock AWSKMS awskms; - private static final String REGION_1 = "us-east-1"; - private static final String REGION_2 = "us-west-2"; - - @Test - void testDeniedRegions() { - AwsKmsClientSupplier supplierWithDefaultValues = new AwsKmsClientSupplier.Builder(kmsClientBuilder) - .build(); - - when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); - - AwsKmsClientSupplier supplierWithDenied = new DenyRegionsAwsKmsClientSupplier( - Collections.singleton(REGION_1), - new AwsKmsClientSupplier.Builder(kmsClientBuilder).build()); - - when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); - when(kmsClientBuilder.build()).thenReturn(awskms); - - assertThrows(UnsupportedRegionException.class, () -> supplierWithDenied.getClient(REGION_1)); - assertNotNull(supplierWithDenied.getClient(REGION_2)); - } -} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java similarity index 69% rename from src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java rename to src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java index d65ef1f4b..dec01d5a3 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsClientSupplierTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliersTest.java @@ -15,6 +15,8 @@ import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers.DefaultAwsKmsClientSupplierBuilder; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.kms.model.AWSKMSException; @@ -28,14 +30,17 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import java.util.Collections; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) -class AwsKmsClientSupplierTest { +class StandardAwsKmsClientSuppliersTest { @Mock AWSKMSClientBuilder kmsClientBuilder; @Mock AWSKMS awskms; @@ -54,7 +59,7 @@ void testCredentialsAndClientConfiguration() { when(kmsClientBuilder.withCredentials(credentialsProvider)).thenReturn(kmsClientBuilder); when(kmsClientBuilder.build()).thenReturn(awskms); - AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) .credentialsProvider(credentialsProvider) .clientConfiguration(clientConfiguration) .build(); @@ -68,7 +73,7 @@ void testCredentialsAndClientConfiguration() { @Test void testClientCaching() { - AwsKmsClientSupplier supplier = new AwsKmsClientSupplier.Builder(kmsClientBuilder) + AwsKmsClientSupplier supplier = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) .build(); when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); @@ -125,4 +130,46 @@ void testGetClientByKeyId() { assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(alias), s -> awskms)); assertEquals(awskms, AwsKmsClientSupplier.getClientByKeyId(AwsKmsCmkId.fromString(keyId), s -> awskms)); } + + @Test + void testAllowedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithAllowed = StandardAwsKmsClientSuppliers + .allowRegionsBuilder(Collections.singleton(REGION_1)) + .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithAllowed.getClient(REGION_1)); + assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); + } + + @Test + void testDeniedRegions() { + AwsKmsClientSupplier supplierWithDefaultValues = new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + AwsKmsClientSupplier supplierWithDenied = StandardAwsKmsClientSuppliers + .denyRegionsBuilder(Collections.singleton(REGION_1)) + .baseClientSupplier(new DefaultAwsKmsClientSupplierBuilder(kmsClientBuilder).build()).build(); + + when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertThrows(UnsupportedRegionException.class, () -> supplierWithDenied.getClient(REGION_1)); + assertNotNull(supplierWithDenied.getClient(REGION_2)); + } } From b8be12f4a0b56da6be47de8faa5eff5cd3662612 Mon Sep 17 00:00:00 2001 From: Wesley Rosenblum Date: Mon, 16 Mar 2020 12:41:21 -0700 Subject: [PATCH 3/3] Using ConcurrentHashMap for the client cache to be thread safe --- .../kms/StandardAwsKmsClientSuppliers.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java index f92de0101..24b71c45c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/StandardAwsKmsClientSuppliers.java @@ -22,10 +22,10 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Proxy; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import static java.util.Objects.requireNonNull; import static org.apache.commons.lang3.Validate.notEmpty; @@ -77,9 +77,10 @@ public static class DefaultAwsKmsClientSupplierBuilder { private AWSCredentialsProvider credentialsProvider; private ClientConfiguration clientConfiguration; - private final Map clientsCache = new HashMap<>(); + private final Map clientsCache = new ConcurrentHashMap<>(); private static final Set AWSKMS_METHODS = new HashSet<>(); private AWSKMSClientBuilder awsKmsClientBuilder; + private static final String NULL_REGION = "null-region"; static { AWSKMS_METHODS.add("generateDataKey"); @@ -95,6 +96,10 @@ public AwsKmsClientSupplier build() { return regionId -> { + if(regionId == null) { + regionId = NULL_REGION; + } + if (clientsCache.containsKey(regionId)) { return clientsCache.get(regionId); } @@ -107,7 +112,7 @@ public AwsKmsClientSupplier build() { awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration); } - if (regionId != null) { + if (!regionId.equals(NULL_REGION)) { awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId); }