Skip to content

Making client suppliers composable #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

package com.amazonaws.crypto.examples.datakeycaching;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.AwsCryptoResult;
Expand All @@ -22,8 +21,8 @@
import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache;
import com.amazonaws.encryptionsdk.keyrings.Keyring;
import com.amazonaws.encryptionsdk.keyrings.StandardKeyrings;
import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;
import com.amazonaws.regions.Region;
import com.amazonaws.services.kinesis.AmazonKinesis;
import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder;
Expand All @@ -32,7 +31,6 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -73,10 +71,10 @@ public MultiRegionRecordPusherExample(final Region[] regions, final String kmsAl
.build());

keyrings.add(StandardKeyrings.awsKmsBuilder()
.awsKmsClientSupplier(AwsKmsClientSupplier.builder()
.credentialsProvider(credentialsProvider)
.allowedRegions(Collections.singleton(region.getName()))
.build())
.awsKmsClientSupplier(StandardAwsKmsClientSuppliers
.allowRegionsBuilder(Collections.singleton(region.getName()))
.baseClientSupplier(StandardAwsKmsClientSuppliers.defaultBuilder()
.credentialsProvider(credentialsProvider).build()).build())
.generatorKeyId(AwsKmsCmkId.fromString(kmsAliasName)).build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;

import java.util.List;

Expand Down Expand Up @@ -111,7 +112,7 @@ public AwsKmsKeyringBuilder generatorKeyId(AwsKmsCmkId generatorKeyId) {
*/
public Keyring build() {
if (awsKmsClientSupplier == null) {
awsKmsClientSupplier = AwsKmsClientSupplier.builder().build();
awsKmsClientSupplier = StandardAwsKmsClientSuppliers.defaultBuilder().build();
}

return new AwsKmsKeyring(DataKeyEncryptionDao.awsKms(awsKmsClientSupplier, grantTokens),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

package com.amazonaws.encryptionsdk.keyrings;

import com.amazonaws.encryptionsdk.kms.AwsKmsClientSupplier;
import com.amazonaws.encryptionsdk.kms.AwsKmsCmkId;
import com.amazonaws.encryptionsdk.kms.StandardAwsKmsClientSuppliers;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -80,15 +80,14 @@ public static AwsKmsKeyringBuilder awsKmsBuilder() {
* AWS KMS Discovery keyrings do not specify any CMKs to decrypt with, and thus will attempt to decrypt
* using any encrypted data key in an encrypted message. AWS KMS Discovery keyrings do not perform encryption.
* <p></p>
* To create an AWS KMS Regional Discovery Keyring, construct an {@link AwsKmsClientSupplier} using
* {@link AwsKmsClientSupplier#builder()} to specify which regions to include/exclude.
* To create an AWS KMS Regional Discovery Keyring, use {@link StandardAwsKmsClientSuppliers#allowRegionsBuilder} or
* {@link StandardAwsKmsClientSuppliers#denyRegionsBuilder} to specify which regions to include/exclude.
* <p></p>
* For example, to include only CMKs in the us-east-1 region:
* <pre>
* StandardKeyrings.awsKmsDiscovery()
* .awsKmsClientSupplier(
* AwsKmsClientSupplier.builder()
* .allowedRegions(Collections.singleton("us-east-1")).build())
* StandardAwsKmsClientSuppliers.allowRegionsBuilder(Collections.singleton("us-east-1")).build()
* .build();
* </pre>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,13 @@

package com.amazonaws.encryptionsdk.kms;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.arn.Arn;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
import com.amazonaws.services.kms.AWSKMS;
import com.amazonaws.services.kms.AWSKMSClientBuilder;
import com.amazonaws.services.kms.model.AWSKMSException;

import javax.annotation.Nullable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import static java.util.Objects.requireNonNull;
import static org.apache.commons.lang3.Validate.isTrue;
import static org.apache.commons.lang3.Validate.notEmpty;

/**
* Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The
Expand All @@ -51,15 +38,6 @@ public interface AwsKmsClientSupplier {
*/
AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException;

/**
* Gets a Builder for constructing an AwsKmsClientSupplier
*
* @return The builder
*/
static Builder builder() {
return new Builder(AWSKMSClientBuilder.standard());
}

/**
* Parses region from the given key id (if possible) and passes that region to the
* given clientSupplier to produce an {@code AWSKMS} client.
Expand All @@ -78,156 +56,4 @@ static AWSKMS getClientByKeyId(AwsKmsCmkId keyId, AwsKmsClientSupplier clientSup

return clientSupplier.getClient(null);
}

/**
* Builder to construct an AwsKmsClientSupplier given various
* optional settings.
*/
class Builder {

private AWSCredentialsProvider credentialsProvider;
private ClientConfiguration clientConfiguration;
private Set<String> allowedRegions = Collections.emptySet();
private Set<String> excludedRegions = Collections.emptySet();
private boolean clientCachingEnabled = true;
private final Map<String, AWSKMS> clientsCache = new HashMap<>();
private static final Set<String> AWSKMS_METHODS = new HashSet<>();
private AWSKMSClientBuilder awsKmsClientBuilder;

static {
AWSKMS_METHODS.add("generateDataKey");
AWSKMS_METHODS.add("encrypt");
AWSKMS_METHODS.add("decrypt");
}

Builder(AWSKMSClientBuilder awsKmsClientBuilder) {
this.awsKmsClientBuilder = awsKmsClientBuilder;
}

public AwsKmsClientSupplier build() {
isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(),
"Either allowed regions or excluded regions may be set, not both.");

return regionId -> {
if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) {
throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s",
regionId, allowedRegions));
}

if (excludedRegions.contains(regionId)) {
throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s",
regionId, excludedRegions));
}

if (clientsCache.containsKey(regionId)) {
return clientsCache.get(regionId);
}

if (credentialsProvider != null) {
awsKmsClientBuilder = awsKmsClientBuilder.withCredentials(credentialsProvider);
}

if (clientConfiguration != null) {
awsKmsClientBuilder = awsKmsClientBuilder.withClientConfiguration(clientConfiguration);
}

if (regionId != null) {
awsKmsClientBuilder = awsKmsClientBuilder.withRegion(regionId);
}

AWSKMS client = awsKmsClientBuilder.build();

if (clientCachingEnabled) {
client = newCachingProxy(client, regionId);
}

return client;
};
}

/**
* Sets the AWSCredentialsProvider used by the client.
*
* @param credentialsProvider New AWSCredentialsProvider to use.
*/
public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) {
this.credentialsProvider = credentialsProvider;
return this;
}

/**
* Sets the ClientConfiguration to be used by the client.
*
* @param clientConfiguration Custom configuration to use.
*/
public Builder clientConfiguration(ClientConfiguration clientConfiguration) {
this.clientConfiguration = clientConfiguration;
return this;
}

/**
* Sets the AWS regions that the client supplier is permitted to use.
*
* @param regions The set of regions.
*/
public Builder allowedRegions(Set<String> regions) {
notEmpty(regions, "At least one region is required");
this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
return this;
}

/**
* Sets the AWS regions that the client supplier is not permitted to use.
*
* @param regions The set of regions.
*/
public Builder excludedRegions(Set<String> regions) {
requireNonNull(regions, "regions is required");
this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
return this;
}

/**
* When set to false, disables the AWSKMS client for each region from being cached and reused.
* By default, client caching is enabled.
*
* @param enabled Whether or not caching is enabled.
*/
public Builder clientCaching(boolean enabled) {
this.clientCachingEnabled = enabled;
return this;
}

/**
* Creates a proxy for the AWSKMS client that will populate the client into the client cache
* after an AWS KMS method successfully completes or an AWS KMS exception occurs. This is to prevent a
* a malicious user from causing a local resource DOS by sending ciphertext with a large number
* of spurious regions, thereby filling the cache with regions and exhausting resources.
*
* @param client The client to proxy
* @param regionId The region the client is associated with
* @return The proxy
*/
private AWSKMS newCachingProxy(AWSKMS client, String regionId) {
return (AWSKMS) Proxy.newProxyInstance(
AWSKMS.class.getClassLoader(),
new Class[]{AWSKMS.class},
(proxy, method, methodArgs) -> {
try {
final Object result = method.invoke(client, methodArgs);
if (AWSKMS_METHODS.contains(method.getName())) {
clientsCache.put(regionId, client);
}
return result;
} catch (InvocationTargetException e) {
if (e.getTargetException() instanceof AWSKMSException &&
AWSKMS_METHODS.contains(method.getName())) {
clientsCache.put(regionId, client);
}

throw e.getTargetException();
}
});
}
}
}
Loading