diff --git a/pom.xml b/pom.xml
index c426db5e6..a8774878b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -42,7 +42,7 @@
com.amazonaws
aws-java-sdk
- 1.11.561
+ 1.11.677
true
diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java
new file mode 100644
index 000000000..58f78833c
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.exception;
+
+/**
+ * This exception is thrown when an Amazon Resource Name is provided that does not
+ * match the CMK Alias or ARN format.
+ */
+public class MalformedArnException extends AwsCryptoException {
+
+ private static final long serialVersionUID = -1L;
+
+ public MalformedArnException() {
+ super();
+ }
+
+ public MalformedArnException(final String message) {
+ super(message);
+ }
+
+ public MalformedArnException(final Throwable cause) {
+ super(cause);
+ }
+
+ public MalformedArnException(final String message, final Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java
new file mode 100644
index 000000000..46a062187
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.exception;
+
+/**
+ * This exception is thrown when the key used by KMS to decrypt a data key does not
+ * match the provider information contained within the encrypted data key.
+ */
+public class MismatchedDataKeyException extends AwsCryptoException {
+
+ private static final long serialVersionUID = -1L;
+
+ public MismatchedDataKeyException() {
+ super();
+ }
+
+ public MismatchedDataKeyException(final String message) {
+ super(message);
+ }
+
+ public MismatchedDataKeyException(final Throwable cause) {
+ super(cause);
+ }
+
+ public MismatchedDataKeyException(final String message, final Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java
new file mode 100644
index 000000000..2c06a602c
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.exception;
+
+/**
+ * This exception is thrown when a region that is not allowed to be used by
+ * a given KmsClientSupplier is specified.
+ */
+public class UnsupportedRegionException extends AwsCryptoException {
+
+ private static final long serialVersionUID = -1L;
+
+ public UnsupportedRegionException() {
+ super();
+ }
+
+ public UnsupportedRegionException(final String message) {
+ super(message);
+ }
+
+ public UnsupportedRegionException(final Throwable cause) {
+ super(cause);
+ }
+
+ public UnsupportedRegionException(final String message, final Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java
new file mode 100644
index 000000000..4b3974501
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.keyrings;
+
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.GenerateDataKeyResult;
+import com.amazonaws.encryptionsdk.kms.KmsUtils;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.isArnWellFormed;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.unmodifiableList;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * A keyring which interacts with AWS Key Management Service (KMS) to create,
+ * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs).
+ */
+class KmsKeyring implements Keyring {
+
+ private final DataKeyEncryptionDao dataKeyEncryptionDao;
+ private final List keyIds;
+ private final String generatorKeyId;
+ private final boolean isDiscovery;
+
+ KmsKeyring(DataKeyEncryptionDao dataKeyEncryptionDao, List keyIds, String generatorKeyId) {
+ requireNonNull(dataKeyEncryptionDao, "dataKeyEncryptionDao is required");
+ this.dataKeyEncryptionDao = dataKeyEncryptionDao;
+ this.keyIds = keyIds == null ? emptyList() : unmodifiableList(new ArrayList<>(keyIds));
+ this.generatorKeyId = generatorKeyId;
+ this.isDiscovery = this.generatorKeyId == null && this.keyIds.isEmpty();
+
+ if (!this.keyIds.stream().allMatch(KmsUtils::isArnWellFormed)) {
+ throw new MalformedArnException("keyIds must contain only CMK aliases and well formed ARNs");
+ }
+
+ if (generatorKeyId != null) {
+ if (!isArnWellFormed(generatorKeyId)) {
+ throw new MalformedArnException("generatorKeyId must be either a CMK alias or a well formed ARN");
+ }
+ if (this.keyIds.contains(generatorKeyId)) {
+ throw new IllegalArgumentException("KeyIds should not contain the generatorKeyId");
+ }
+ }
+ }
+
+ @Override
+ public void onEncrypt(EncryptionMaterials encryptionMaterials) {
+ requireNonNull(encryptionMaterials, "encryptionMaterials are required");
+
+ // If this keyring is a discovery keyring, OnEncrypt MUST return the input encryption materials unmodified.
+ if (isDiscovery) {
+ return;
+ }
+
+ // If the input encryption materials do not contain a plaintext data key and this keyring does not
+ // have a generator defined, OnEncrypt MUST not modify the encryption materials and MUST fail.
+ if (!encryptionMaterials.hasPlaintextDataKey() && generatorKeyId == null) {
+ throw new AwsCryptoException("Encryption materials must contain either a plaintext data key or a generator");
+ }
+
+ final List keyIdsToEncrypt = new ArrayList<>(keyIds);
+
+ // If the input encryption materials do not contain a plaintext data key and a generator is defined onEncrypt
+ // MUST attempt to generate a new plaintext data key and encrypt that data key by calling KMS GenerateDataKey.
+ if (!encryptionMaterials.hasPlaintextDataKey()) {
+ generateDataKey(encryptionMaterials);
+ } else if (generatorKeyId != null) {
+ // If this keyring's generator is defined and was not used to generate a data key, OnEncrypt
+ // MUST also attempt to encrypt the plaintext data key using the CMK specified by the generator.
+ keyIdsToEncrypt.add(generatorKeyId);
+ }
+
+ // Given a plaintext data key in the encryption materials, OnEncrypt MUST attempt
+ // to encrypt the plaintext data key using each CMK specified in it's key IDs list.
+ for (String keyId : keyIdsToEncrypt) {
+ encryptDataKey(keyId, encryptionMaterials);
+ }
+ }
+
+ private void generateDataKey(final EncryptionMaterials encryptionMaterials) {
+ final GenerateDataKeyResult result = dataKeyEncryptionDao.generateDataKey(generatorKeyId,
+ encryptionMaterials.getAlgorithmSuite(), encryptionMaterials.getEncryptionContext());
+
+ encryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.GENERATED_DATA_KEY));
+ encryptionMaterials.addEncryptedDataKey(result.getEncryptedDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT));
+ }
+
+ private void encryptDataKey(final String keyId, final EncryptionMaterials encryptionMaterials) {
+ final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(keyId,
+ encryptionMaterials.getPlaintextDataKey(), encryptionMaterials.getEncryptionContext());
+
+ encryptionMaterials.addEncryptedDataKey(encryptedDataKey,
+ new KeyringTraceEntry(KMS_PROVIDER_ID, keyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT));
+ }
+
+ @Override
+ public void onDecrypt(DecryptionMaterials decryptionMaterials, List extends EncryptedDataKey> encryptedDataKeys) {
+ requireNonNull(decryptionMaterials, "decryptionMaterials are required");
+ requireNonNull(encryptedDataKeys, "encryptedDataKeys are required");
+
+ if (decryptionMaterials.hasPlaintextDataKey() || encryptedDataKeys.isEmpty()) {
+ return;
+ }
+
+ final Set configuredKeyIds = new HashSet<>(keyIds);
+
+ if (generatorKeyId != null) {
+ configuredKeyIds.add(generatorKeyId);
+ }
+
+ for (EncryptedDataKey encryptedDataKey : encryptedDataKeys) {
+ if (okToDecrypt(encryptedDataKey, configuredKeyIds)) {
+ try {
+ final DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(encryptedDataKey,
+ decryptionMaterials.getAlgorithmSuite(), decryptionMaterials.getEncryptionContext());
+
+ decryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, result.getKeyArn(),
+ KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT));
+ return;
+ } catch (CannotUnwrapDataKeyException e) {
+ continue;
+ }
+ }
+ }
+ }
+
+ private boolean okToDecrypt(EncryptedDataKey encryptedDataKey, Set configuredKeyIds) {
+ // Only attempt to decrypt keys provided by KMS
+ if (!encryptedDataKey.getProviderId().equals(KMS_PROVIDER_ID)) {
+ return false;
+ }
+
+ // If the key ARN cannot be parsed, skip it
+ if(!isArnWellFormed(new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING)))
+ {
+ return false;
+ }
+
+ // If this keyring is a discovery keyring, OnDecrypt MUST attempt to
+ // decrypt every encrypted data key in the input encrypted data key list
+ if (isDiscovery) {
+ return true;
+ }
+
+ // OnDecrypt MUST attempt to decrypt each input encrypted data key in the input
+ // encrypted data key list where the key provider info has a value equal to one
+ // of the ARNs in this keyring's key IDs or the generator
+ return configuredKeyIds.contains(new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING));
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
index 34b7e22f4..9dc96cbf8 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java
@@ -13,6 +13,9 @@
package com.amazonaws.encryptionsdk.keyrings;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.KmsClientSupplier;
+
import javax.crypto.SecretKey;
import java.security.PrivateKey;
import java.security.PublicKey;
@@ -56,6 +59,22 @@ public static Keyring rawRsa(String keyNamespace, String keyName, PublicKey publ
return new RawRsaKeyring(keyNamespace, keyName, publicKey, privateKey, wrappingAlgorithm);
}
+ /**
+ * Constructs a {@code Keyring} which interacts with AWS Key Management Service (KMS) to create,
+ * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs).
+ *
+ * @param clientSupplier A function that returns a KMS client that can make GenerateDataKey,
+ * Encrypt, and Decrypt calls in a particular AWS region.
+ * @param grantTokens A list of string grant tokens to be included in all KMS calls.
+ * @param keyIds A list of strings identifying KMS CMKs, in ARN, CMK Alias, or ARN Alias format.
+ * @param generator A string that identifies a KMS CMK responsible for generating a data key,
+ * as well as encrypting and decrypting data keys in ARN, CMK Alias, or ARN Alias format.
+ * @return The {@code Keyring}
+ */
+ public static Keyring kms(KmsClientSupplier clientSupplier, List grantTokens, List keyIds, String generator) {
+ return new KmsKeyring(DataKeyEncryptionDao.kms(clientSupplier, grantTokens), keyIds, generator);
+ }
+
/**
* Constructs a {@code Keyring} which combines other keyrings, allowing one OnEncrypt or OnDecrypt call
* to modify the encryption or decryption materials using more than one keyring.
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java
new file mode 100644
index 000000000..4267ba6f8
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+
+import javax.crypto.SecretKey;
+import java.util.List;
+import java.util.Map;
+
+public interface DataKeyEncryptionDao {
+
+ /**
+ * Generates a unique data key, returning both the plaintext copy of the key and an encrypted copy encrypted using
+ * the customer master key specified by the given keyId.
+ *
+ * @param keyId The customer master key to encrypt the generated key with.
+ * @param algorithmSuite The algorithm suite associated with the key.
+ * @param encryptionContext The encryption context.
+ * @return GenerateDataKeyResult containing the plaintext data key and the encrypted data key.
+ */
+ GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext);
+
+ /**
+ * Encrypts the given plaintext data key using the customer aster key specified by the given keyId.
+ *
+ * @param keyId The customer master key to encrypt the plaintext data key with.
+ * @param plaintextDataKey The plaintext data key to encrypt.
+ * @param encryptionContext The encryption context.
+ * @return The encrypted data key.
+ */
+ EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext);
+
+ /**
+ * Decrypted the given encrypted data key.
+ *
+ * @param encryptedDataKey The encrypted data key to decrypt.
+ * @param algorithmSuite The algorithm suite associated with the key.
+ * @param encryptionContext The encryption context.
+ * @return DecryptDataKeyResult containing the plaintext data key and the ARN of the key that decrypted it.
+ */
+ DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext);
+
+ /**
+ * Constructs an instance of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for
+ * generation, encryption, and decryption of data keys.
+ *
+ * @param clientSupplier A supplier of AWSKMS clients
+ * @param grantTokens A list of grant tokens to supply to KMS
+ * @return The DataKeyEncryptionDao
+ */
+ static DataKeyEncryptionDao kms(KmsClientSupplier clientSupplier, List grantTokens) {
+ return new KmsDataKeyEncryptionDao(clientSupplier, grantTokens);
+ }
+
+ class GenerateDataKeyResult {
+ private final SecretKey plaintextDataKey;
+ private final EncryptedDataKey encryptedDataKey;
+
+ public GenerateDataKeyResult(SecretKey plaintextDataKey, EncryptedDataKey encryptedDataKey) {
+ this.plaintextDataKey = plaintextDataKey;
+ this.encryptedDataKey = encryptedDataKey;
+ }
+
+ public SecretKey getPlaintextDataKey() {
+ return plaintextDataKey;
+ }
+
+ public EncryptedDataKey getEncryptedDataKey() {
+ return encryptedDataKey;
+ }
+ }
+
+ class DecryptDataKeyResult {
+ private final String keyArn;
+ private final SecretKey plaintextDataKey;
+
+ public DecryptDataKeyResult(String keyArn, SecretKey plaintextDataKey) {
+ this.keyArn = keyArn;
+ this.plaintextDataKey = plaintextDataKey;
+ }
+
+ public String getKeyArn() {
+ return keyArn;
+ }
+
+ public SecretKey getPlaintextDataKey() {
+ return plaintextDataKey;
+ }
+
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java
new file mode 100644
index 000000000..6182e161c
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java
@@ -0,0 +1,212 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.ClientConfiguration;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.AWSKMSClientBuilder;
+import com.amazonaws.services.kms.model.AWSKMSException;
+
+import javax.annotation.Nullable;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Proxy;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+import static org.apache.commons.lang3.Validate.isTrue;
+import static org.apache.commons.lang3.Validate.notEmpty;
+
+/**
+ * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The
+ * function should be able to handle when the region is null.
+ */
+@FunctionalInterface
+public interface KmsClientSupplier {
+
+ /**
+ * Gets an {@code AWSKMS} client for the given regionId.
+ *
+ * @param regionId The AWS region (or null)
+ * @return The AWSKMS client
+ * @throws UnsupportedRegionException if a regionId is specified that this
+ * client supplier is configured to not allow.
+ */
+ AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException;
+
+ /**
+ * Gets a Builder for constructing a KmsClientSupplier
+ *
+ * @return The builder
+ */
+ static Builder builder() {
+ return new Builder(AWSKMSClientBuilder.standard());
+ }
+
+ /**
+ * Builder to construct a KmsClientSupplier given various
+ * optional settings.
+ */
+ class Builder {
+
+ private AWSCredentialsProvider credentialsProvider;
+ private ClientConfiguration clientConfiguration;
+ private Set allowedRegions = Collections.emptySet();
+ private Set excludedRegions = Collections.emptySet();
+ private boolean clientCachingEnabled = false;
+ private final Map clientsCache = new HashMap<>();
+ private static final Set KMS_METHODS = new HashSet<>();
+ private AWSKMSClientBuilder kmsClientBuilder;
+
+ static {
+ KMS_METHODS.add("generateDataKey");
+ KMS_METHODS.add("encrypt");
+ KMS_METHODS.add("decrypt");
+ }
+
+ Builder(AWSKMSClientBuilder kmsClientBuilder) {
+ this.kmsClientBuilder = kmsClientBuilder;
+ }
+
+ public KmsClientSupplier build() {
+ isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(),
+ "Either allowed regions or excluded regions may be set, not both.");
+
+ return regionId -> {
+ if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) {
+ throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s",
+ regionId, allowedRegions));
+ }
+
+ if (excludedRegions.contains(regionId)) {
+ throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s",
+ regionId, excludedRegions));
+ }
+
+ if (clientsCache.containsKey(regionId)) {
+ return clientsCache.get(regionId);
+ }
+
+ if (credentialsProvider != null) {
+ kmsClientBuilder = kmsClientBuilder.withCredentials(credentialsProvider);
+ }
+
+ if (clientConfiguration != null) {
+ kmsClientBuilder = kmsClientBuilder.withClientConfiguration(clientConfiguration);
+ }
+
+ if (regionId != null) {
+ kmsClientBuilder = kmsClientBuilder.withRegion(regionId);
+ }
+
+ AWSKMS client = kmsClientBuilder.build();
+
+ if (clientCachingEnabled) {
+ client = newCachingProxy(client, regionId);
+ }
+
+ return client;
+ };
+ }
+
+ /**
+ * Sets the AWSCredentialsProvider used by the client.
+ *
+ * @param credentialsProvider New AWSCredentialsProvider to use.
+ */
+ public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) {
+ this.credentialsProvider = credentialsProvider;
+ return this;
+ }
+
+ /**
+ * Sets the ClientConfiguration to be used by the client.
+ *
+ * @param clientConfiguration Custom configuration to use.
+ */
+ public Builder clientConfiguration(ClientConfiguration clientConfiguration) {
+ this.clientConfiguration = clientConfiguration;
+ return this;
+ }
+
+ /**
+ * Sets the AWS regions that the client supplier is permitted to use.
+ *
+ * @param regions The set of regions.
+ */
+ public Builder allowedRegions(Set regions) {
+ notEmpty(regions, "At least one region is required");
+ this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
+ return this;
+ }
+
+ /**
+ * Sets the AWS regions that the client supplier is not permitted to use.
+ *
+ * @param regions The set of regions.
+ */
+ public Builder excludedRegions(Set regions) {
+ requireNonNull(regions, "regions is required");
+ this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions));
+ return this;
+ }
+
+ /**
+ * When set to true, allows for the AWSKMS client for each region to be cached and reused.
+ *
+ * @param enabled Whether or not caching is enabled.
+ */
+ public Builder clientCaching(boolean enabled) {
+ this.clientCachingEnabled = enabled;
+ return this;
+ }
+
+ /**
+ * Creates a proxy for the AWSKMS client that will populate the client into the client cache
+ * after a KMS method successfully completes or a KMS exception occurs. This is to prevent a
+ * a malicious user from causing a local resource DOS by sending ciphertext with a large number
+ * of spurious regions, thereby filling the cache with regions and exhausting resources.
+ *
+ * @param client The client to proxy
+ * @param regionId The region the client is associated with
+ * @return The proxy
+ */
+ private AWSKMS newCachingProxy(AWSKMS client, String regionId) {
+ return (AWSKMS) Proxy.newProxyInstance(
+ AWSKMS.class.getClassLoader(),
+ new Class[]{AWSKMS.class},
+ (proxy, method, methodArgs) -> {
+ try {
+ final Object result = method.invoke(client, methodArgs);
+ if (KMS_METHODS.contains(method.getName())) {
+ clientsCache.put(regionId, client);
+ }
+ return result;
+ } catch (InvocationTargetException e) {
+ if (e.getTargetException() instanceof AWSKMSException &&
+ KMS_METHODS.contains(method.getName())) {
+ clientsCache.put(regionId, client);
+ }
+
+ throw e.getTargetException();
+ }
+ });
+ }
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java
new file mode 100644
index 000000000..a00a1f1c7
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.AmazonServiceException;
+import com.amazonaws.AmazonWebServiceRequest;
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.encryptionsdk.internal.VersionInfo;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import com.amazonaws.services.kms.model.DecryptRequest;
+import com.amazonaws.services.kms.model.EncryptRequest;
+import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.getClientByArn;
+import static java.util.Objects.requireNonNull;
+import static org.apache.commons.lang3.Validate.isTrue;
+
+/**
+ * An implementation of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for
+ * generation, encryption, and decryption of data keys. The KmsMethods interface is implemented
+ * to allow usage in KmsMasterKey.
+ */
+class KmsDataKeyEncryptionDao implements DataKeyEncryptionDao, KmsMethods {
+
+ private final KmsClientSupplier clientSupplier;
+ private List grantTokens;
+
+ KmsDataKeyEncryptionDao(KmsClientSupplier clientSupplier, List grantTokens) {
+ requireNonNull(clientSupplier, "clientSupplier is required");
+
+ this.clientSupplier = clientSupplier;
+ this.grantTokens = grantTokens == null ? new ArrayList<>() : new ArrayList<>(grantTokens);
+ }
+
+ @Override
+ public GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext) {
+ requireNonNull(keyId, "keyId is required");
+ requireNonNull(algorithmSuite, "algorithmSuite is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+
+ final com.amazonaws.services.kms.model.GenerateDataKeyResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(keyId, clientSupplier)
+ .generateDataKey(updateUserAgent(
+ new GenerateDataKeyRequest()
+ .withKeyId(keyId)
+ .withNumberOfBytes(algorithmSuite.getDataKeyLength())
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException ex) {
+ throw new AwsCryptoException(ex);
+ }
+
+ final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()];
+ kmsResult.getPlaintext().get(rawKey);
+ if (kmsResult.getPlaintext().remaining() > 0) {
+ throw new IllegalStateException("Received an unexpected number of bytes from KMS");
+ }
+ final byte[] encryptedKey = new byte[kmsResult.getCiphertextBlob().remaining()];
+ kmsResult.getCiphertextBlob().get(encryptedKey);
+
+ return new GenerateDataKeyResult(new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()),
+ new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedKey));
+ }
+
+ @Override
+ public EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext) {
+ requireNonNull(keyId, "keyId is required");
+ requireNonNull(plaintextDataKey, "plaintextDataKey is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+ isTrue(plaintextDataKey.getFormat().equals("RAW"), "Only RAW encoded keys are supported");
+
+ final com.amazonaws.services.kms.model.EncryptResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(keyId, clientSupplier)
+ .encrypt(updateUserAgent(new EncryptRequest()
+ .withKeyId(keyId)
+ .withPlaintext(ByteBuffer.wrap(plaintextDataKey.getEncoded()))
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException ex) {
+ throw new AwsCryptoException(ex);
+ }
+ final byte[] encryptedDataKey = new byte[kmsResult.getCiphertextBlob().remaining()];
+ kmsResult.getCiphertextBlob().get(encryptedDataKey);
+
+ return new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedDataKey);
+
+ }
+
+ @Override
+ public DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext) {
+ requireNonNull(encryptedDataKey, "encryptedDataKey is required");
+ requireNonNull(algorithmSuite, "algorithmSuite is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+
+ final String providerInformation = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING);
+ final com.amazonaws.services.kms.model.DecryptResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(providerInformation, clientSupplier)
+ .decrypt(updateUserAgent(new DecryptRequest()
+ .withCiphertextBlob(ByteBuffer.wrap(encryptedDataKey.getEncryptedDataKey()))
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException | UnsupportedRegionException ex) {
+ throw new CannotUnwrapDataKeyException(ex);
+ }
+
+ if (!kmsResult.getKeyId().equals(providerInformation)) {
+ throw new MismatchedDataKeyException("Received an unexpected key Id from KMS");
+ }
+
+ final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()];
+ kmsResult.getPlaintext().get(rawKey);
+ if (kmsResult.getPlaintext().remaining() > 0) {
+ throw new IllegalStateException("Received an unexpected number of bytes from KMS");
+ }
+
+ return new DecryptDataKeyResult(kmsResult.getKeyId(), new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()));
+
+ }
+
+ private T updateUserAgent(T request) {
+ request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT);
+
+ return request;
+ }
+
+ @Override
+ public void setGrantTokens(List grantTokens) {
+ this.grantTokens = new ArrayList<>(grantTokens);
+ }
+
+ @Override
+ public List getGrantTokens() {
+ return Collections.unmodifiableList(grantTokens);
+ }
+
+ @Override
+ public void addGrantToken(String grantToken) {
+ grantTokens.add(grantToken);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
index b78840221..60c69445c 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
@@ -14,17 +14,12 @@
package com.amazonaws.encryptionsdk.kms;
import javax.crypto.SecretKey;
-import javax.crypto.spec.SecretKeySpec;
-import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
-import com.amazonaws.AmazonServiceException;
-import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.encryptionsdk.AwsCrypto;
@@ -35,30 +30,18 @@
import com.amazonaws.encryptionsdk.MasterKeyProvider;
import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException;
-import com.amazonaws.encryptionsdk.internal.VersionInfo;
import com.amazonaws.services.kms.AWSKMS;
-import com.amazonaws.services.kms.model.DecryptRequest;
-import com.amazonaws.services.kms.model.DecryptResult;
-import com.amazonaws.services.kms.model.EncryptRequest;
-import com.amazonaws.services.kms.model.EncryptResult;
-import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
-import com.amazonaws.services.kms.model.GenerateDataKeyResult;
+
+import static java.util.Collections.emptyList;
/**
* Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with
* {@link AwsCrypto}.
*/
public final class KmsMasterKey extends MasterKey implements KmsMethods {
- private final Supplier kms_;
+ private final KmsDataKeyEncryptionDao dataKeyEncryptionDao_;
private final MasterKeyProvider sourceProvider_;
private final String id_;
- private final List grantTokens_ = new ArrayList<>();
-
- private T updateUserAgent(T request) {
- request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT);
-
- return request;
- }
/**
*
@@ -80,11 +63,11 @@ public static KmsMasterKey getInstance(final AWSCredentialsProvider creds, final
static KmsMasterKey getInstance(final Supplier kms, final String id,
final MasterKeyProvider provider) {
- return new KmsMasterKey(kms, id, provider);
+ return new KmsMasterKey(new KmsDataKeyEncryptionDao(s -> kms.get(), emptyList()), id, provider);
}
- private KmsMasterKey(final Supplier kms, final String id, final MasterKeyProvider provider) {
- kms_ = kms;
+ KmsMasterKey(final KmsDataKeyEncryptionDao dataKeyEncryptionDao, final String id, final MasterKeyProvider provider) {
+ dataKeyEncryptionDao_ = dataKeyEncryptionDao;
id_ = id;
sourceProvider_ = provider;
}
@@ -102,39 +85,27 @@ public String getKeyId() {
@Override
public DataKey generateDataKey(final CryptoAlgorithm algorithm,
final Map encryptionContext) {
- final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent(
- new GenerateDataKeyRequest()
- .withKeyId(getKeyId())
- .withNumberOfBytes(algorithm.getDataKeyLength())
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)
- ));
- final byte[] rawKey = new byte[algorithm.getDataKeyLength()];
- gdkResult.getPlaintext().get(rawKey);
- if (gdkResult.getPlaintext().remaining() > 0) {
- throw new IllegalStateException("Recieved an unexpected number of bytes from KMS");
- }
- final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()];
- gdkResult.getCiphertextBlob().get(encryptedKey);
-
- final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo());
- return new DataKey<>(key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this);
+ final DataKeyEncryptionDao.GenerateDataKeyResult gdkResult = dataKeyEncryptionDao_.generateDataKey(
+ getKeyId(), algorithm, encryptionContext);
+ return new DataKey<>(gdkResult.getPlaintextDataKey(),
+ gdkResult.getEncryptedDataKey().getEncryptedDataKey(),
+ gdkResult.getEncryptedDataKey().getProviderInformation(),
+ this);
}
@Override
public void setGrantTokens(final List grantTokens) {
- grantTokens_.clear();
- grantTokens_.addAll(grantTokens);
+ dataKeyEncryptionDao_.setGrantTokens(grantTokens);
}
@Override
public List getGrantTokens() {
- return grantTokens_;
+ return dataKeyEncryptionDao_.getGrantTokens();
}
@Override
public void addGrantToken(final String grantToken) {
- grantTokens_.add(grantToken);
+ dataKeyEncryptionDao_.addGrantToken(grantToken);
}
@Override
@@ -142,22 +113,12 @@ public DataKey encryptDataKey(final CryptoAlgorithm algorithm,
final Map encryptionContext,
final DataKey> dataKey) {
final SecretKey key = dataKey.getKey();
- if (!key.getFormat().equals("RAW")) {
- throw new IllegalArgumentException("Only RAW encoded keys are supported");
- }
- try {
- final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent(
- new EncryptRequest()
- .withKeyId(id_)
- .withPlaintext(ByteBuffer.wrap(key.getEncoded()))
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)));
- final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()];
- encryptResult.getCiphertextBlob().get(edk);
- return new DataKey<>(dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this);
- } catch (final AmazonServiceException asex) {
- throw new AwsCryptoException(asex);
- }
+ final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao_.encryptDataKey(id_, key, encryptionContext);
+
+ return new DataKey<>(dataKey.getKey(),
+ encryptedDataKey.getEncryptedDataKey(),
+ encryptedDataKey.getProviderInformation(),
+ this);
}
@Override
@@ -168,24 +129,13 @@ public DataKey decryptDataKey(final CryptoAlgorithm algorithm,
final List exceptions = new ArrayList<>();
for (final EncryptedDataKey edk : encryptedDataKeys) {
try {
- final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent(
- new DecryptRequest()
- .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey()))
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)));
- if (decryptResult.getKeyId().equals(id_)) {
- final byte[] rawKey = new byte[algorithm.getDataKeyLength()];
- decryptResult.getPlaintext().get(rawKey);
- if (decryptResult.getPlaintext().remaining() > 0) {
- throw new IllegalStateException("Received an unexpected number of bytes from KMS");
- }
- return new DataKey<>(
- new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()),
- edk.getEncryptedDataKey(),
- edk.getProviderInformation(), this);
- }
- } catch (final AmazonServiceException awsex) {
- exceptions.add(awsex);
+ final DataKeyEncryptionDao.DecryptDataKeyResult result = dataKeyEncryptionDao_.decryptDataKey(edk, algorithm, encryptionContext);
+ return new DataKey<>(
+ result.getPlaintextDataKey(),
+ edk.getEncryptedDataKey(),
+ edk.getProviderInformation(), this);
+ } catch (final AwsCryptoException ex) {
+ exceptions.add(ex);
}
}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java
new file mode 100644
index 000000000..f6aab16eb
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.arn.Arn;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.services.kms.AWSKMS;
+
+public class KmsUtils {
+
+ private static final String ALIAS_PREFIX = "alias/";
+ private static final String ARN_PREFIX = "arn:";
+ /**
+ * The provider ID used for the KmsKeyring
+ */
+ public static final String KMS_PROVIDER_ID = "aws-kms";
+
+ /**
+ * Parses region from the given arn (if possible) and passes that region to the
+ * given clientSupplier to produce an {@code AWSKMS} client.
+ *
+ * @param arn The Amazon Resource Name or Key Alias
+ * @param clientSupplier The client supplier
+ * @return AWSKMS The client
+ * @throws MalformedArnException if the arn is malformed
+ */
+ public static AWSKMS getClientByArn(String arn, KmsClientSupplier clientSupplier) throws MalformedArnException {
+ if (isKeyAlias(arn)) {
+ return clientSupplier.getClient(null);
+ }
+
+ if(isArn(arn)) {
+ try {
+ return clientSupplier.getClient(Arn.fromString(arn).getRegion());
+ } catch (IllegalArgumentException e) {
+ throw new MalformedArnException(e);
+ }
+ }
+
+ // Not an alias or an ARN, must be a raw Key ID
+ return clientSupplier.getClient(null);
+ }
+
+ /**
+ * Returns true if the given arn is a well formed Amazon Resource Name or Key Alias. Does
+ * not return true for raw key IDs.
+ *
+ * @param arn The Amazon Resource Name or Key Alias
+ * @return True if well formed, false otherwise
+ */
+ public static boolean isArnWellFormed(String arn) {
+ if (isKeyAlias(arn)) {
+ return true;
+ }
+
+ try {
+ Arn.fromString(arn);
+ return true;
+ } catch (IllegalArgumentException e) {
+ return false;
+ }
+ }
+
+ private static boolean isKeyAlias(String arn) {
+ return arn.startsWith(ALIAS_PREFIX);
+ }
+
+ private static boolean isArn(String arn) {
+ return arn.startsWith(ARN_PREFIX);
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java
new file mode 100644
index 000000000..6ed3629ff
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java
@@ -0,0 +1,334 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.keyrings;
+
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.ENCRYPTED_DATA_KEY;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.GENERATED_DATA_KEY;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class KmsKeyringTest {
+
+ private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256;
+ private static final SecretKey PLAINTEXT_DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
+ private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue");
+ private static final String GENERATOR_KEY_ID = "arn:aws:kms:us-east-1:999999999999:key/generator-89ab-cdef-fedc-ba9876543210";
+ private static final String KEY_ID_1 = "arn:aws:kms:us-east-1:999999999999:key/key1-23bv-sdfs-werw-234323nfdsf";
+ private static final String KEY_ID_2 = "arn:aws:kms:us-east-1:999999999999:key/key2-02ds-wvjs-aswe-a4923489273";
+ private static final EncryptedDataKey ENCRYPTED_GENERATOR_KEY = new KeyBlob(KMS_PROVIDER_ID,
+ GENERATOR_KEY_ID.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final EncryptedDataKey ENCRYPTED_KEY_1 = new KeyBlob(KMS_PROVIDER_ID,
+ KEY_ID_1.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final EncryptedDataKey ENCRYPTED_KEY_2 = new KeyBlob(KMS_PROVIDER_ID,
+ KEY_ID_2.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final KeyringTraceEntry ENCRYPTED_GENERATOR_KEY_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry ENCRYPTED_KEY_1_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry ENCRYPTED_KEY_2_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry GENERATED_DATA_KEY_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, GENERATED_DATA_KEY);
+ @Mock(lenient = true) private DataKeyEncryptionDao dataKeyEncryptionDao;
+ private Keyring keyring;
+
+ @BeforeEach
+ void setup() {
+ when(dataKeyEncryptionDao.encryptDataKey(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_GENERATOR_KEY);
+ when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_1, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_1);
+ when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_2, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_2);
+
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_GENERATOR_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY));
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(KEY_ID_1, PLAINTEXT_DATA_KEY));
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(KEY_ID_2, PLAINTEXT_DATA_KEY));
+
+ List keyIds = new ArrayList<>();
+ keyIds.add(KEY_ID_1);
+ keyIds.add(KEY_ID_2);
+ keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, GENERATOR_KEY_ID);
+ }
+
+ @Test
+ void testMalformedArns() {
+ assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, null, "badArn"));
+ assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList("badArn"), GENERATOR_KEY_ID));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(new KeyBlob(KMS_PROVIDER_ID, "badArn".getBytes(PROVIDER_ENCODING), new byte[]{}));
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ // Malformed Arn for a non KMS provider shouldn't fail
+ encryptedDataKeys.clear();
+ encryptedDataKeys.add(new KeyBlob("OtherProviderId", "badArn".getBytes(PROVIDER_ENCODING), new byte[]{}));
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+ }
+
+ @Test
+ void testGeneratorKeyInKeyIds() {
+ assertThrows(IllegalArgumentException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(GENERATOR_KEY_ID), GENERATOR_KEY_ID));
+ }
+
+ @Test
+ void testEncryptDecryptExistingDataKey() {
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .plaintextDataKey(PLAINTEXT_DATA_KEY)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertEquals(3, encryptionMaterials.getEncryptedDataKeys().size());
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_GENERATOR_KEY));
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1));
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_2));
+
+ assertEquals(3, encryptionMaterials.getKeyringTrace().getEntries().size());
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY);
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testEncryptNullDataKey() {
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .keyringTrace(new KeyringTrace())
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ when(dataKeyEncryptionDao.generateDataKey(GENERATOR_KEY_ID, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenReturn(new DataKeyEncryptionDao.GenerateDataKeyResult(PLAINTEXT_DATA_KEY, ENCRYPTED_GENERATOR_KEY));
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey());
+
+ assertEquals(4, encryptionMaterials.getKeyringTrace().getEntries().size());
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(GENERATED_DATA_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY);
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testEncryptNullGenerator() {
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .keyringTrace(new KeyringTrace())
+ .plaintextDataKey(PLAINTEXT_DATA_KEY)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ Keyring keyring = new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(KEY_ID_1), null);
+
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertEquals(1, encryptionMaterials.getEncryptedDataKeys().size());
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1));
+
+ assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey());
+
+ assertEquals(1, encryptionMaterials.getKeyringTrace().getEntries().size());
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE));
+ }
+
+ @Test
+ void testDiscoveryEncrypt() {
+ keyring = new KmsKeyring(dataKeyEncryptionDao, null, null);
+
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertFalse(encryptionMaterials.hasPlaintextDataKey());
+ assertEquals(0, encryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+
+ @Test
+ void testEncryptNoGeneratorOrPlaintextDataKey() {
+ List keyIds = new ArrayList<>();
+ keyIds.add(KEY_ID_1);
+ keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, null);
+
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE).build();
+ assertThrows(AwsCryptoException.class, () -> keyring.onEncrypt(encryptionMaterials));
+ }
+
+ @Test
+ void testDecryptFirstKeyFails() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new CannotUnwrapDataKeyException());
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDecryptMismatchedDataKeyException() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new MismatchedDataKeyException());
+
+ assertThrows(MismatchedDataKeyException.class, () -> keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_KEY_1)));
+ }
+
+ @Test
+ void testDecryptFirstKeyWrongProvider() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ EncryptedDataKey wrongProviderKey = new KeyBlob("OtherProvider", KEY_ID_1.getBytes(PROVIDER_ENCODING), new byte[]{});
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(wrongProviderKey);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDiscoveryDecrypt() {
+ keyring = new KmsKeyring(dataKeyEncryptionDao, null, null);
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDecryptAlreadyDecryptedDataKey() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .plaintextDataKey(PLAINTEXT_DATA_KEY)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_GENERATOR_KEY));
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+ assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+
+ @Test
+ void testDecryptNoDataKey() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ keyring.onDecrypt(decryptionMaterials, Collections.emptyList());
+
+ assertFalse(decryptionMaterials.hasPlaintextDataKey());
+ assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java
new file mode 100644
index 000000000..cdca99fd6
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.ClientConfiguration;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.AWSKMSClientBuilder;
+import com.amazonaws.services.kms.model.AWSKMSException;
+import com.amazonaws.services.kms.model.DecryptRequest;
+import com.amazonaws.services.kms.model.EncryptRequest;
+import com.amazonaws.services.kms.model.EncryptResult;
+import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
+import com.amazonaws.services.kms.model.KMSInvalidStateException;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.util.Collections;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class KmsClientSupplierTest {
+
+ @Mock AWSKMSClientBuilder kmsClientBuilder;
+ @Mock AWSKMS awskms;
+ @Mock EncryptRequest encryptRequest;
+ @Mock DecryptRequest decryptRequest;
+ @Mock GenerateDataKeyRequest generateDataKeyRequest;
+ @Mock AWSCredentialsProvider credentialsProvider;
+ @Mock ClientConfiguration clientConfiguration;
+ private static final String REGION_1 = "us-east-1";
+ private static final String REGION_2 = "us-west-2";
+ private static final String REGION_3 = "eu-west-1";
+
+ @Test
+ void testCredentialsAndClientConfiguration() {
+ when(kmsClientBuilder.withClientConfiguration(clientConfiguration)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.withCredentials(credentialsProvider)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ KmsClientSupplier supplier = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .credentialsProvider(credentialsProvider)
+ .clientConfiguration(clientConfiguration)
+ .build();
+
+ supplier.getClient(null);
+
+ verify(kmsClientBuilder).withClientConfiguration(clientConfiguration);
+ verify(kmsClientBuilder).withCredentials(credentialsProvider);
+ verify(kmsClientBuilder).build();
+ }
+
+ @Test
+ void testAllowedAndExcludedRegions() {
+ KmsClientSupplier supplierWithDefaultValues = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertNotNull(supplierWithDefaultValues.getClient(REGION_1));
+
+ KmsClientSupplier supplierWithAllowed = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .allowedRegions(Collections.singleton(REGION_1))
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertNotNull(supplierWithAllowed.getClient(REGION_1));
+ assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2));
+
+ KmsClientSupplier supplierWithExcluded = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .excludedRegions(Collections.singleton(REGION_1))
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ assertThrows(UnsupportedRegionException.class, () -> supplierWithExcluded.getClient(REGION_1));
+ assertNotNull(supplierWithExcluded.getClient(REGION_2));
+
+ assertThrows(IllegalArgumentException.class, () -> new KmsClientSupplier.Builder(kmsClientBuilder)
+ .allowedRegions(Collections.singleton(REGION_1))
+ .excludedRegions(Collections.singleton(REGION_2))
+ .build());
+ }
+
+ @Test
+ void testClientCachingDisabled() {
+ KmsClientSupplier supplierCachingDisabled = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .clientCaching(false)
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ AWSKMS uncachedClient = supplierCachingDisabled.getClient(REGION_1);
+ verify(kmsClientBuilder, times(1)).build();
+
+ when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult());
+
+ uncachedClient.encrypt(encryptRequest);
+ supplierCachingDisabled.getClient(REGION_1);
+ verify(kmsClientBuilder, times(2)).build();
+ }
+
+ @Test
+ void testClientCaching() {
+ KmsClientSupplier supplier = new KmsClientSupplier.Builder(kmsClientBuilder)
+ .clientCaching(true)
+ .build();
+
+ when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.withRegion(REGION_3)).thenReturn(kmsClientBuilder);
+ when(kmsClientBuilder.build()).thenReturn(awskms);
+
+ AWSKMS client = supplier.getClient(REGION_1);
+ AWSKMS client2 = supplier.getClient(REGION_2);
+ AWSKMS client3 = supplier.getClient(REGION_3);
+ verify(kmsClientBuilder, times(3)).build();
+
+ // No KMS methods have been called yet, so clients remain uncached
+ supplier.getClient(REGION_1);
+ supplier.getClient(REGION_2);
+ supplier.getClient(REGION_3);
+ verify(kmsClientBuilder, times(6)).build();
+
+ when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult());
+ when(awskms.decrypt(decryptRequest)).thenThrow(new KMSInvalidStateException("test"));
+ when(awskms.generateDataKey(generateDataKeyRequest)).thenThrow(new IllegalArgumentException("test"));
+
+ // Successful KMS call, client is cached
+ client.encrypt(encryptRequest);
+ supplier.getClient(REGION_1);
+ verify(kmsClientBuilder, times(6)).build();
+
+ // KMS call resulted in KMS exception, client is cached
+ assertThrows(AWSKMSException.class, () -> client2.decrypt(decryptRequest));
+ supplier.getClient(REGION_2);
+ verify(kmsClientBuilder, times(6)).build();
+
+ // KMS call resulted in non-KMS exception, client is not cached
+ assertThrows(IllegalArgumentException.class, () -> client3.generateDataKey(generateDataKeyRequest));
+ supplier.getClient(REGION_3);
+ verify(kmsClientBuilder, times(7)).build();
+
+ // Non-KMS method, client is not cached
+ client3.toString();
+ supplier.getClient(REGION_3);
+ verify(kmsClientBuilder, times(8)).build();
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java
new file mode 100644
index 000000000..e9dc140d7
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java
@@ -0,0 +1,243 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.AmazonWebServiceRequest;
+import com.amazonaws.RequestClientOptions;
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException;
+import com.amazonaws.encryptionsdk.internal.VersionInfo;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.model.DecryptRequest;
+import com.amazonaws.services.kms.model.DecryptResult;
+import com.amazonaws.services.kms.model.EncryptRequest;
+import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
+import com.amazonaws.services.kms.model.KMSInvalidStateException;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class KmsDataKeyEncryptionDaoTest {
+
+ private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256;
+ private static final SecretKey DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
+ private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken");
+ private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue");
+ private static final EncryptedDataKey ENCRYPTED_DATA_KEY = new KeyBlob(KMS_PROVIDER_ID,
+ "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210".getBytes(EncryptedDataKey.PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+
+ @Test
+ void testEncryptAndDecrypt() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class);
+ verify(client, times(1)).encrypt(er.capture());
+
+ EncryptRequest actualRequest = er.getValue();
+
+ assertEquals(keyId, actualRequest.getKeyId());
+ assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext());
+ assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array());
+ assertUserAgent(actualRequest);
+
+ assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId());
+ assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation());
+ assertNotNull(encryptedDataKeyResult.getEncryptedDataKey());
+
+ DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(encryptedDataKeyResult, ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class);
+ verify(client, times(1)).decrypt(decrypt.capture());
+
+ DecryptRequest actualDecryptRequest = decrypt.getValue();
+ assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext());
+ assertArrayEquals(encryptedDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array());
+ assertUserAgent(actualDecryptRequest);
+
+ assertEquals(DATA_KEY, decryptDataKeyResult.getPlaintextDataKey());
+ assertEquals(keyId, decryptDataKeyResult.getKeyArn());
+ }
+
+ @Test
+ void testGenerateAndDecrypt() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ DataKeyEncryptionDao.GenerateDataKeyResult generateDataKeyResult = dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class);
+ verify(client, times(1)).generateDataKey(gr.capture());
+
+ GenerateDataKeyRequest actualRequest = gr.getValue();
+
+ assertEquals(keyId, actualRequest.getKeyId());
+ assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext());
+ assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes());
+ assertUserAgent(actualRequest);
+
+ assertNotNull(generateDataKeyResult.getPlaintextDataKey());
+ assertEquals(ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getPlaintextDataKey().getEncoded().length);
+ assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getPlaintextDataKey().getAlgorithm());
+ assertNotNull(generateDataKeyResult.getEncryptedDataKey());
+
+ DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(generateDataKeyResult.getEncryptedDataKey(), ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class);
+ verify(client, times(1)).decrypt(decrypt.capture());
+
+ DecryptRequest actualDecryptRequest = decrypt.getValue();
+ assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext());
+ assertArrayEquals(generateDataKeyResult.getEncryptedDataKey().getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array());
+ assertUserAgent(actualDecryptRequest);
+
+ assertEquals(generateDataKeyResult.getPlaintextDataKey(), decryptDataKeyResult.getPlaintextDataKey());
+ assertEquals(keyId, decryptDataKeyResult.getKeyArn());
+ }
+
+ @Test
+ void testEncryptWithRawKeyId() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ String rawKeyId = keyId.split("/")[1];
+ EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(rawKeyId, DATA_KEY, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class);
+ verify(client, times(1)).encrypt(er.capture());
+
+ EncryptRequest actualRequest = er.getValue();
+
+ assertEquals(rawKeyId, actualRequest.getKeyId());
+ assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext());
+ assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array());
+ assertUserAgent(actualRequest);
+
+ assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId());
+ assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation());
+ assertNotNull(encryptedDataKeyResult.getEncryptedDataKey());
+ }
+
+ @Test
+ void testEncryptWrongKeyFormat() {
+ SecretKey key = mock(SecretKey.class);
+ when(key.getFormat()).thenReturn("BadFormat");
+
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+
+ assertThrows(IllegalArgumentException.class, () -> dao.encryptDataKey(keyId, key, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testKmsFailure() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ doThrow(new KMSInvalidStateException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class));
+ doThrow(new KMSInvalidStateException("fail")).when(client).encrypt(isA(EncryptRequest.class));
+ doThrow(new KMSInvalidStateException("fail")).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testUnsupportedRegionException() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ doThrow(new UnsupportedRegionException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class));
+ doThrow(new UnsupportedRegionException("fail")).when(client).encrypt(isA(EncryptRequest.class));
+ doThrow(new UnsupportedRegionException("fail")).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testDecryptBadKmsKeyId() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ DecryptResult badResult = new DecryptResult();
+ badResult.setKeyId("badKeyId");
+
+ doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(MismatchedDataKeyException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testDecryptBadKmsKeyLength() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ DecryptResult badResult = new DecryptResult();
+ badResult.setKeyId(new String(ENCRYPTED_DATA_KEY.getProviderInformation(), EncryptedDataKey.PROVIDER_ENCODING));
+ badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1));
+
+ doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(IllegalStateException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ private void assertUserAgent(AmazonWebServiceRequest request) {
+ assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT)
+ .contains(VersionInfo.USER_AGENT));
+ }
+
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java
new file mode 100644
index 000000000..e4c731ee7
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.DataKey;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class KmsMasterKeyTest {
+
+ private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384;
+ private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("test", "value");
+ private static final String CMK_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210";
+ @Mock KmsDataKeyEncryptionDao dataKeyEncryptionDao;
+
+ /**
+ * Test that when decryption of an encrypted data key throws a MismatchedDataKeyException, this
+ * key is skipped and another key in the list of keys is decrypted.
+ */
+ @Test
+ void testMismatchedDataKeyException() {
+ EncryptedDataKey encryptedDataKey1 = new KeyBlob(KMS_PROVIDER_ID, "KeyId1".getBytes(PROVIDER_ENCODING), generate(64));
+ EncryptedDataKey encryptedDataKey2 = new KeyBlob(KMS_PROVIDER_ID, "KeyId2".getBytes(PROVIDER_ENCODING), generate(64));
+ SecretKey secretKey = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
+
+ when(dataKeyEncryptionDao.decryptDataKey(encryptedDataKey1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenThrow(new MismatchedDataKeyException());
+ when(dataKeyEncryptionDao.decryptDataKey(encryptedDataKey2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DataKeyEncryptionDao.DecryptDataKeyResult("KeyId2", secretKey));
+
+ KmsMasterKey kmsMasterKey = new KmsMasterKey(dataKeyEncryptionDao, CMK_ARN, null);
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(encryptedDataKey1);
+ encryptedDataKeys.add(encryptedDataKey2);
+
+ DataKey result = kmsMasterKey.decryptDataKey(ALGORITHM_SUITE, encryptedDataKeys, ENCRYPTION_CONTEXT);
+
+ assertEquals(secretKey, result.getKey());
+ }
+
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java
new file mode 100644
index 000000000..e8bd05477
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.services.kms.AWSKMS;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+@ExtendWith(MockitoExtension.class)
+class KmsUtilsTest {
+
+ private static final String VALID_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210";
+ private static final String VALID_ALIAS_ARN = "arn:aws:kms:us-east-1:999999999999:alias/MyCryptoKey";
+ private static final String VALID_ALIAS = "alias/MyCryptoKey";
+ private static final String VALID_RAW_KEY_ID = "01234567-89ab-cdef-fedc-ba9876543210";
+
+ @Mock
+ private AWSKMS client;
+
+
+ @Test
+ void testGetClientByArn() {
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ARN, s -> client));
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS_ARN, s -> client));
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS, s -> client));
+ assertThrows(MalformedArnException.class, () -> KmsUtils.getClientByArn("arn:invalid", s -> client));
+ assertEquals(client, KmsUtils.getClientByArn(VALID_RAW_KEY_ID, s -> client));
+ }
+
+ @Test
+ void testIsArnWellFormed() {
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ARN));
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS_ARN));
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS));
+ assertFalse(KmsUtils.isArnWellFormed(VALID_RAW_KEY_ID));
+ assertFalse(KmsUtils.isArnWellFormed("arn:invalid"));
+
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
index 37fe9cbff..00ce5c074 100644
--- a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
@@ -29,7 +29,7 @@
import com.amazonaws.ResponseMetadata;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
-import com.amazonaws.services.kms.AWSKMSClient;
+import com.amazonaws.services.kms.AbstractAWSKMS;
import com.amazonaws.services.kms.model.CreateAliasRequest;
import com.amazonaws.services.kms.model.CreateAliasResult;
import com.amazonaws.services.kms.model.CreateGrantRequest;
@@ -85,7 +85,7 @@
import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest;
import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult;
-public class MockKMSClient extends AWSKMSClient {
+public class MockKMSClient extends AbstractAWSKMS {
private static final SecureRandom rnd = new SecureRandom();
private static final String ACCOUNT_ID = "01234567890";
private final Map results_ = new HashMap<>();