diff --git a/pom.xml b/pom.xml index c426db5e6..a8774878b 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ com.amazonaws aws-java-sdk - 1.11.561 + 1.11.677 true diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java new file mode 100644 index 000000000..58f78833c --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.exception; + +/** + * This exception is thrown when an Amazon Resource Name is provided that does not + * match the CMK Alias or ARN format. + */ +public class MalformedArnException extends AwsCryptoException { + + private static final long serialVersionUID = -1L; + + public MalformedArnException() { + super(); + } + + public MalformedArnException(final String message) { + super(message); + } + + public MalformedArnException(final Throwable cause) { + super(cause); + } + + public MalformedArnException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java new file mode 100644 index 000000000..46a062187 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.exception; + +/** + * This exception is thrown when the key used by KMS to decrypt a data key does not + * match the provider information contained within the encrypted data key. + */ +public class MismatchedDataKeyException extends AwsCryptoException { + + private static final long serialVersionUID = -1L; + + public MismatchedDataKeyException() { + super(); + } + + public MismatchedDataKeyException(final String message) { + super(message); + } + + public MismatchedDataKeyException(final Throwable cause) { + super(cause); + } + + public MismatchedDataKeyException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java new file mode 100644 index 000000000..2c06a602c --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedRegionException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.exception; + +/** + * This exception is thrown when a region that is not allowed to be used by + * a given KmsClientSupplier is specified. + */ +public class UnsupportedRegionException extends AwsCryptoException { + + private static final long serialVersionUID = -1L; + + public UnsupportedRegionException() { + super(); + } + + public UnsupportedRegionException(final String message) { + super(message); + } + + public UnsupportedRegionException(final Throwable cause) { + super(cause); + } + + public UnsupportedRegionException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java new file mode 100644 index 000000000..4b3974501 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java @@ -0,0 +1,176 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.keyrings; + +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.GenerateDataKeyResult; +import com.amazonaws.encryptionsdk.kms.KmsUtils; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.isArnWellFormed; +import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +/** + * A keyring which interacts with AWS Key Management Service (KMS) to create, + * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs). + */ +class KmsKeyring implements Keyring { + + private final DataKeyEncryptionDao dataKeyEncryptionDao; + private final List keyIds; + private final String generatorKeyId; + private final boolean isDiscovery; + + KmsKeyring(DataKeyEncryptionDao dataKeyEncryptionDao, List keyIds, String generatorKeyId) { + requireNonNull(dataKeyEncryptionDao, "dataKeyEncryptionDao is required"); + this.dataKeyEncryptionDao = dataKeyEncryptionDao; + this.keyIds = keyIds == null ? emptyList() : unmodifiableList(new ArrayList<>(keyIds)); + this.generatorKeyId = generatorKeyId; + this.isDiscovery = this.generatorKeyId == null && this.keyIds.isEmpty(); + + if (!this.keyIds.stream().allMatch(KmsUtils::isArnWellFormed)) { + throw new MalformedArnException("keyIds must contain only CMK aliases and well formed ARNs"); + } + + if (generatorKeyId != null) { + if (!isArnWellFormed(generatorKeyId)) { + throw new MalformedArnException("generatorKeyId must be either a CMK alias or a well formed ARN"); + } + if (this.keyIds.contains(generatorKeyId)) { + throw new IllegalArgumentException("KeyIds should not contain the generatorKeyId"); + } + } + } + + @Override + public void onEncrypt(EncryptionMaterials encryptionMaterials) { + requireNonNull(encryptionMaterials, "encryptionMaterials are required"); + + // If this keyring is a discovery keyring, OnEncrypt MUST return the input encryption materials unmodified. + if (isDiscovery) { + return; + } + + // If the input encryption materials do not contain a plaintext data key and this keyring does not + // have a generator defined, OnEncrypt MUST not modify the encryption materials and MUST fail. + if (!encryptionMaterials.hasPlaintextDataKey() && generatorKeyId == null) { + throw new AwsCryptoException("Encryption materials must contain either a plaintext data key or a generator"); + } + + final List keyIdsToEncrypt = new ArrayList<>(keyIds); + + // If the input encryption materials do not contain a plaintext data key and a generator is defined onEncrypt + // MUST attempt to generate a new plaintext data key and encrypt that data key by calling KMS GenerateDataKey. + if (!encryptionMaterials.hasPlaintextDataKey()) { + generateDataKey(encryptionMaterials); + } else if (generatorKeyId != null) { + // If this keyring's generator is defined and was not used to generate a data key, OnEncrypt + // MUST also attempt to encrypt the plaintext data key using the CMK specified by the generator. + keyIdsToEncrypt.add(generatorKeyId); + } + + // Given a plaintext data key in the encryption materials, OnEncrypt MUST attempt + // to encrypt the plaintext data key using each CMK specified in it's key IDs list. + for (String keyId : keyIdsToEncrypt) { + encryptDataKey(keyId, encryptionMaterials); + } + } + + private void generateDataKey(final EncryptionMaterials encryptionMaterials) { + final GenerateDataKeyResult result = dataKeyEncryptionDao.generateDataKey(generatorKeyId, + encryptionMaterials.getAlgorithmSuite(), encryptionMaterials.getEncryptionContext()); + + encryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.GENERATED_DATA_KEY)); + encryptionMaterials.addEncryptedDataKey(result.getEncryptedDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT)); + } + + private void encryptDataKey(final String keyId, final EncryptionMaterials encryptionMaterials) { + final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(keyId, + encryptionMaterials.getPlaintextDataKey(), encryptionMaterials.getEncryptionContext()); + + encryptionMaterials.addEncryptedDataKey(encryptedDataKey, + new KeyringTraceEntry(KMS_PROVIDER_ID, keyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT)); + } + + @Override + public void onDecrypt(DecryptionMaterials decryptionMaterials, List encryptedDataKeys) { + requireNonNull(decryptionMaterials, "decryptionMaterials are required"); + requireNonNull(encryptedDataKeys, "encryptedDataKeys are required"); + + if (decryptionMaterials.hasPlaintextDataKey() || encryptedDataKeys.isEmpty()) { + return; + } + + final Set configuredKeyIds = new HashSet<>(keyIds); + + if (generatorKeyId != null) { + configuredKeyIds.add(generatorKeyId); + } + + for (EncryptedDataKey encryptedDataKey : encryptedDataKeys) { + if (okToDecrypt(encryptedDataKey, configuredKeyIds)) { + try { + final DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(encryptedDataKey, + decryptionMaterials.getAlgorithmSuite(), decryptionMaterials.getEncryptionContext()); + + decryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, result.getKeyArn(), + KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT)); + return; + } catch (CannotUnwrapDataKeyException e) { + continue; + } + } + } + } + + private boolean okToDecrypt(EncryptedDataKey encryptedDataKey, Set configuredKeyIds) { + // Only attempt to decrypt keys provided by KMS + if (!encryptedDataKey.getProviderId().equals(KMS_PROVIDER_ID)) { + return false; + } + + // If the key ARN cannot be parsed, skip it + if(!isArnWellFormed(new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING))) + { + return false; + } + + // If this keyring is a discovery keyring, OnDecrypt MUST attempt to + // decrypt every encrypted data key in the input encrypted data key list + if (isDiscovery) { + return true; + } + + // OnDecrypt MUST attempt to decrypt each input encrypted data key in the input + // encrypted data key list where the key provider info has a value equal to one + // of the ARNs in this keyring's key IDs or the generator + return configuredKeyIds.contains(new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING)); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java index 34b7e22f4..9dc96cbf8 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/StandardKeyrings.java @@ -13,6 +13,9 @@ package com.amazonaws.encryptionsdk.keyrings; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.KmsClientSupplier; + import javax.crypto.SecretKey; import java.security.PrivateKey; import java.security.PublicKey; @@ -56,6 +59,22 @@ public static Keyring rawRsa(String keyNamespace, String keyName, PublicKey publ return new RawRsaKeyring(keyNamespace, keyName, publicKey, privateKey, wrappingAlgorithm); } + /** + * Constructs a {@code Keyring} which interacts with AWS Key Management Service (KMS) to create, + * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs). + * + * @param clientSupplier A function that returns a KMS client that can make GenerateDataKey, + * Encrypt, and Decrypt calls in a particular AWS region. + * @param grantTokens A list of string grant tokens to be included in all KMS calls. + * @param keyIds A list of strings identifying KMS CMKs, in ARN, CMK Alias, or ARN Alias format. + * @param generator A string that identifies a KMS CMK responsible for generating a data key, + * as well as encrypting and decrypting data keys in ARN, CMK Alias, or ARN Alias format. + * @return The {@code Keyring} + */ + public static Keyring kms(KmsClientSupplier clientSupplier, List grantTokens, List keyIds, String generator) { + return new KmsKeyring(DataKeyEncryptionDao.kms(clientSupplier, grantTokens), keyIds, generator); + } + /** * Constructs a {@code Keyring} which combines other keyrings, allowing one OnEncrypt or OnDecrypt call * to modify the encryption or decryption materials using more than one keyring. diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java new file mode 100644 index 000000000..4267ba6f8 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; + +import javax.crypto.SecretKey; +import java.util.List; +import java.util.Map; + +public interface DataKeyEncryptionDao { + + /** + * Generates a unique data key, returning both the plaintext copy of the key and an encrypted copy encrypted using + * the customer master key specified by the given keyId. + * + * @param keyId The customer master key to encrypt the generated key with. + * @param algorithmSuite The algorithm suite associated with the key. + * @param encryptionContext The encryption context. + * @return GenerateDataKeyResult containing the plaintext data key and the encrypted data key. + */ + GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext); + + /** + * Encrypts the given plaintext data key using the customer aster key specified by the given keyId. + * + * @param keyId The customer master key to encrypt the plaintext data key with. + * @param plaintextDataKey The plaintext data key to encrypt. + * @param encryptionContext The encryption context. + * @return The encrypted data key. + */ + EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext); + + /** + * Decrypted the given encrypted data key. + * + * @param encryptedDataKey The encrypted data key to decrypt. + * @param algorithmSuite The algorithm suite associated with the key. + * @param encryptionContext The encryption context. + * @return DecryptDataKeyResult containing the plaintext data key and the ARN of the key that decrypted it. + */ + DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext); + + /** + * Constructs an instance of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for + * generation, encryption, and decryption of data keys. + * + * @param clientSupplier A supplier of AWSKMS clients + * @param grantTokens A list of grant tokens to supply to KMS + * @return The DataKeyEncryptionDao + */ + static DataKeyEncryptionDao kms(KmsClientSupplier clientSupplier, List grantTokens) { + return new KmsDataKeyEncryptionDao(clientSupplier, grantTokens); + } + + class GenerateDataKeyResult { + private final SecretKey plaintextDataKey; + private final EncryptedDataKey encryptedDataKey; + + public GenerateDataKeyResult(SecretKey plaintextDataKey, EncryptedDataKey encryptedDataKey) { + this.plaintextDataKey = plaintextDataKey; + this.encryptedDataKey = encryptedDataKey; + } + + public SecretKey getPlaintextDataKey() { + return plaintextDataKey; + } + + public EncryptedDataKey getEncryptedDataKey() { + return encryptedDataKey; + } + } + + class DecryptDataKeyResult { + private final String keyArn; + private final SecretKey plaintextDataKey; + + public DecryptDataKeyResult(String keyArn, SecretKey plaintextDataKey) { + this.keyArn = keyArn; + this.plaintextDataKey = plaintextDataKey; + } + + public String getKeyArn() { + return keyArn; + } + + public SecretKey getPlaintextDataKey() { + return plaintextDataKey; + } + + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java new file mode 100644 index 000000000..6182e161c --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java @@ -0,0 +1,212 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import com.amazonaws.services.kms.model.AWSKMSException; + +import javax.annotation.Nullable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Proxy; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.isTrue; +import static org.apache.commons.lang3.Validate.notEmpty; + +/** + * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The + * function should be able to handle when the region is null. + */ +@FunctionalInterface +public interface KmsClientSupplier { + + /** + * Gets an {@code AWSKMS} client for the given regionId. + * + * @param regionId The AWS region (or null) + * @return The AWSKMS client + * @throws UnsupportedRegionException if a regionId is specified that this + * client supplier is configured to not allow. + */ + AWSKMS getClient(@Nullable String regionId) throws UnsupportedRegionException; + + /** + * Gets a Builder for constructing a KmsClientSupplier + * + * @return The builder + */ + static Builder builder() { + return new Builder(AWSKMSClientBuilder.standard()); + } + + /** + * Builder to construct a KmsClientSupplier given various + * optional settings. + */ + class Builder { + + private AWSCredentialsProvider credentialsProvider; + private ClientConfiguration clientConfiguration; + private Set allowedRegions = Collections.emptySet(); + private Set excludedRegions = Collections.emptySet(); + private boolean clientCachingEnabled = false; + private final Map clientsCache = new HashMap<>(); + private static final Set KMS_METHODS = new HashSet<>(); + private AWSKMSClientBuilder kmsClientBuilder; + + static { + KMS_METHODS.add("generateDataKey"); + KMS_METHODS.add("encrypt"); + KMS_METHODS.add("decrypt"); + } + + Builder(AWSKMSClientBuilder kmsClientBuilder) { + this.kmsClientBuilder = kmsClientBuilder; + } + + public KmsClientSupplier build() { + isTrue(allowedRegions.isEmpty() || excludedRegions.isEmpty(), + "Either allowed regions or excluded regions may be set, not both."); + + return regionId -> { + if (!allowedRegions.isEmpty() && !allowedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is not in the list of allowed regions %s", + regionId, allowedRegions)); + } + + if (excludedRegions.contains(regionId)) { + throw new UnsupportedRegionException(String.format("Region %s is in the list of excluded regions %s", + regionId, excludedRegions)); + } + + if (clientsCache.containsKey(regionId)) { + return clientsCache.get(regionId); + } + + if (credentialsProvider != null) { + kmsClientBuilder = kmsClientBuilder.withCredentials(credentialsProvider); + } + + if (clientConfiguration != null) { + kmsClientBuilder = kmsClientBuilder.withClientConfiguration(clientConfiguration); + } + + if (regionId != null) { + kmsClientBuilder = kmsClientBuilder.withRegion(regionId); + } + + AWSKMS client = kmsClientBuilder.build(); + + if (clientCachingEnabled) { + client = newCachingProxy(client, regionId); + } + + return client; + }; + } + + /** + * Sets the AWSCredentialsProvider used by the client. + * + * @param credentialsProvider New AWSCredentialsProvider to use. + */ + public Builder credentialsProvider(AWSCredentialsProvider credentialsProvider) { + this.credentialsProvider = credentialsProvider; + return this; + } + + /** + * Sets the ClientConfiguration to be used by the client. + * + * @param clientConfiguration Custom configuration to use. + */ + public Builder clientConfiguration(ClientConfiguration clientConfiguration) { + this.clientConfiguration = clientConfiguration; + return this; + } + + /** + * Sets the AWS regions that the client supplier is permitted to use. + * + * @param regions The set of regions. + */ + public Builder allowedRegions(Set regions) { + notEmpty(regions, "At least one region is required"); + this.allowedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); + return this; + } + + /** + * Sets the AWS regions that the client supplier is not permitted to use. + * + * @param regions The set of regions. + */ + public Builder excludedRegions(Set regions) { + requireNonNull(regions, "regions is required"); + this.excludedRegions = Collections.unmodifiableSet(new HashSet<>(regions)); + return this; + } + + /** + * When set to true, allows for the AWSKMS client for each region to be cached and reused. + * + * @param enabled Whether or not caching is enabled. + */ + public Builder clientCaching(boolean enabled) { + this.clientCachingEnabled = enabled; + return this; + } + + /** + * Creates a proxy for the AWSKMS client that will populate the client into the client cache + * after a KMS method successfully completes or a KMS exception occurs. This is to prevent a + * a malicious user from causing a local resource DOS by sending ciphertext with a large number + * of spurious regions, thereby filling the cache with regions and exhausting resources. + * + * @param client The client to proxy + * @param regionId The region the client is associated with + * @return The proxy + */ + private AWSKMS newCachingProxy(AWSKMS client, String regionId) { + return (AWSKMS) Proxy.newProxyInstance( + AWSKMS.class.getClassLoader(), + new Class[]{AWSKMS.class}, + (proxy, method, methodArgs) -> { + try { + final Object result = method.invoke(client, methodArgs); + if (KMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + return result; + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof AWSKMSException && + KMS_METHODS.contains(method.getName())) { + clientsCache.put(regionId, client); + } + + throw e.getTargetException(); + } + }); + } + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java new file mode 100644 index 000000000..a00a1f1c7 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java @@ -0,0 +1,172 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.AmazonServiceException; +import com.amazonaws.AmazonWebServiceRequest; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.encryptionsdk.internal.VersionInfo; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.getClientByArn; +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.isTrue; + +/** + * An implementation of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for + * generation, encryption, and decryption of data keys. The KmsMethods interface is implemented + * to allow usage in KmsMasterKey. + */ +class KmsDataKeyEncryptionDao implements DataKeyEncryptionDao, KmsMethods { + + private final KmsClientSupplier clientSupplier; + private List grantTokens; + + KmsDataKeyEncryptionDao(KmsClientSupplier clientSupplier, List grantTokens) { + requireNonNull(clientSupplier, "clientSupplier is required"); + + this.clientSupplier = clientSupplier; + this.grantTokens = grantTokens == null ? new ArrayList<>() : new ArrayList<>(grantTokens); + } + + @Override + public GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext) { + requireNonNull(keyId, "keyId is required"); + requireNonNull(algorithmSuite, "algorithmSuite is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + + final com.amazonaws.services.kms.model.GenerateDataKeyResult kmsResult; + + try { + kmsResult = getClientByArn(keyId, clientSupplier) + .generateDataKey(updateUserAgent( + new GenerateDataKeyRequest() + .withKeyId(keyId) + .withNumberOfBytes(algorithmSuite.getDataKeyLength()) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException ex) { + throw new AwsCryptoException(ex); + } + + final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()]; + kmsResult.getPlaintext().get(rawKey); + if (kmsResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); + } + final byte[] encryptedKey = new byte[kmsResult.getCiphertextBlob().remaining()]; + kmsResult.getCiphertextBlob().get(encryptedKey); + + return new GenerateDataKeyResult(new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()), + new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedKey)); + } + + @Override + public EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext) { + requireNonNull(keyId, "keyId is required"); + requireNonNull(plaintextDataKey, "plaintextDataKey is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + isTrue(plaintextDataKey.getFormat().equals("RAW"), "Only RAW encoded keys are supported"); + + final com.amazonaws.services.kms.model.EncryptResult kmsResult; + + try { + kmsResult = getClientByArn(keyId, clientSupplier) + .encrypt(updateUserAgent(new EncryptRequest() + .withKeyId(keyId) + .withPlaintext(ByteBuffer.wrap(plaintextDataKey.getEncoded())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException ex) { + throw new AwsCryptoException(ex); + } + final byte[] encryptedDataKey = new byte[kmsResult.getCiphertextBlob().remaining()]; + kmsResult.getCiphertextBlob().get(encryptedDataKey); + + return new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedDataKey); + + } + + @Override + public DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext) { + requireNonNull(encryptedDataKey, "encryptedDataKey is required"); + requireNonNull(algorithmSuite, "algorithmSuite is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + + final String providerInformation = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING); + final com.amazonaws.services.kms.model.DecryptResult kmsResult; + + try { + kmsResult = getClientByArn(providerInformation, clientSupplier) + .decrypt(updateUserAgent(new DecryptRequest() + .withCiphertextBlob(ByteBuffer.wrap(encryptedDataKey.getEncryptedDataKey())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException | UnsupportedRegionException ex) { + throw new CannotUnwrapDataKeyException(ex); + } + + if (!kmsResult.getKeyId().equals(providerInformation)) { + throw new MismatchedDataKeyException("Received an unexpected key Id from KMS"); + } + + final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()]; + kmsResult.getPlaintext().get(rawKey); + if (kmsResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); + } + + return new DecryptDataKeyResult(kmsResult.getKeyId(), new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo())); + + } + + private T updateUserAgent(T request) { + request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT); + + return request; + } + + @Override + public void setGrantTokens(List grantTokens) { + this.grantTokens = new ArrayList<>(grantTokens); + } + + @Override + public List getGrantTokens() { + return Collections.unmodifiableList(grantTokens); + } + + @Override + public void addGrantToken(String grantToken) { + grantTokens.add(grantToken); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java index b78840221..60c69445c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java @@ -14,17 +14,12 @@ package com.amazonaws.encryptionsdk.kms; import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Supplier; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.AwsCrypto; @@ -35,30 +30,18 @@ import com.amazonaws.encryptionsdk.MasterKeyProvider; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; -import com.amazonaws.encryptionsdk.internal.VersionInfo; import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.model.DecryptRequest; -import com.amazonaws.services.kms.model.DecryptResult; -import com.amazonaws.services.kms.model.EncryptRequest; -import com.amazonaws.services.kms.model.EncryptResult; -import com.amazonaws.services.kms.model.GenerateDataKeyRequest; -import com.amazonaws.services.kms.model.GenerateDataKeyResult; + +import static java.util.Collections.emptyList; /** * Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with * {@link AwsCrypto}. */ public final class KmsMasterKey extends MasterKey implements KmsMethods { - private final Supplier kms_; + private final KmsDataKeyEncryptionDao dataKeyEncryptionDao_; private final MasterKeyProvider sourceProvider_; private final String id_; - private final List grantTokens_ = new ArrayList<>(); - - private T updateUserAgent(T request) { - request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT); - - return request; - } /** * @@ -80,11 +63,11 @@ public static KmsMasterKey getInstance(final AWSCredentialsProvider creds, final static KmsMasterKey getInstance(final Supplier kms, final String id, final MasterKeyProvider provider) { - return new KmsMasterKey(kms, id, provider); + return new KmsMasterKey(new KmsDataKeyEncryptionDao(s -> kms.get(), emptyList()), id, provider); } - private KmsMasterKey(final Supplier kms, final String id, final MasterKeyProvider provider) { - kms_ = kms; + KmsMasterKey(final KmsDataKeyEncryptionDao dataKeyEncryptionDao, final String id, final MasterKeyProvider provider) { + dataKeyEncryptionDao_ = dataKeyEncryptionDao; id_ = id; sourceProvider_ = provider; } @@ -102,39 +85,27 @@ public String getKeyId() { @Override public DataKey generateDataKey(final CryptoAlgorithm algorithm, final Map encryptionContext) { - final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent( - new GenerateDataKeyRequest() - .withKeyId(getKeyId()) - .withNumberOfBytes(algorithm.getDataKeyLength()) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_) - )); - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - gdkResult.getPlaintext().get(rawKey); - if (gdkResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Recieved an unexpected number of bytes from KMS"); - } - final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; - gdkResult.getCiphertextBlob().get(encryptedKey); - - final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); - return new DataKey<>(key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); + final DataKeyEncryptionDao.GenerateDataKeyResult gdkResult = dataKeyEncryptionDao_.generateDataKey( + getKeyId(), algorithm, encryptionContext); + return new DataKey<>(gdkResult.getPlaintextDataKey(), + gdkResult.getEncryptedDataKey().getEncryptedDataKey(), + gdkResult.getEncryptedDataKey().getProviderInformation(), + this); } @Override public void setGrantTokens(final List grantTokens) { - grantTokens_.clear(); - grantTokens_.addAll(grantTokens); + dataKeyEncryptionDao_.setGrantTokens(grantTokens); } @Override public List getGrantTokens() { - return grantTokens_; + return dataKeyEncryptionDao_.getGrantTokens(); } @Override public void addGrantToken(final String grantToken) { - grantTokens_.add(grantToken); + dataKeyEncryptionDao_.addGrantToken(grantToken); } @Override @@ -142,22 +113,12 @@ public DataKey encryptDataKey(final CryptoAlgorithm algorithm, final Map encryptionContext, final DataKey dataKey) { final SecretKey key = dataKey.getKey(); - if (!key.getFormat().equals("RAW")) { - throw new IllegalArgumentException("Only RAW encoded keys are supported"); - } - try { - final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent( - new EncryptRequest() - .withKeyId(id_) - .withPlaintext(ByteBuffer.wrap(key.getEncoded())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; - encryptResult.getCiphertextBlob().get(edk); - return new DataKey<>(dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); - } catch (final AmazonServiceException asex) { - throw new AwsCryptoException(asex); - } + final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao_.encryptDataKey(id_, key, encryptionContext); + + return new DataKey<>(dataKey.getKey(), + encryptedDataKey.getEncryptedDataKey(), + encryptedDataKey.getProviderInformation(), + this); } @Override @@ -168,24 +129,13 @@ public DataKey decryptDataKey(final CryptoAlgorithm algorithm, final List exceptions = new ArrayList<>(); for (final EncryptedDataKey edk : encryptedDataKeys) { try { - final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent( - new DecryptRequest() - .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - if (decryptResult.getKeyId().equals(id_)) { - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - decryptResult.getPlaintext().get(rawKey); - if (decryptResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Received an unexpected number of bytes from KMS"); - } - return new DataKey<>( - new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), - edk.getEncryptedDataKey(), - edk.getProviderInformation(), this); - } - } catch (final AmazonServiceException awsex) { - exceptions.add(awsex); + final DataKeyEncryptionDao.DecryptDataKeyResult result = dataKeyEncryptionDao_.decryptDataKey(edk, algorithm, encryptionContext); + return new DataKey<>( + result.getPlaintextDataKey(), + edk.getEncryptedDataKey(), + edk.getProviderInformation(), this); + } catch (final AwsCryptoException ex) { + exceptions.add(ex); } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java new file mode 100644 index 000000000..f6aab16eb --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java @@ -0,0 +1,82 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.arn.Arn; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.services.kms.AWSKMS; + +public class KmsUtils { + + private static final String ALIAS_PREFIX = "alias/"; + private static final String ARN_PREFIX = "arn:"; + /** + * The provider ID used for the KmsKeyring + */ + public static final String KMS_PROVIDER_ID = "aws-kms"; + + /** + * Parses region from the given arn (if possible) and passes that region to the + * given clientSupplier to produce an {@code AWSKMS} client. + * + * @param arn The Amazon Resource Name or Key Alias + * @param clientSupplier The client supplier + * @return AWSKMS The client + * @throws MalformedArnException if the arn is malformed + */ + public static AWSKMS getClientByArn(String arn, KmsClientSupplier clientSupplier) throws MalformedArnException { + if (isKeyAlias(arn)) { + return clientSupplier.getClient(null); + } + + if(isArn(arn)) { + try { + return clientSupplier.getClient(Arn.fromString(arn).getRegion()); + } catch (IllegalArgumentException e) { + throw new MalformedArnException(e); + } + } + + // Not an alias or an ARN, must be a raw Key ID + return clientSupplier.getClient(null); + } + + /** + * Returns true if the given arn is a well formed Amazon Resource Name or Key Alias. Does + * not return true for raw key IDs. + * + * @param arn The Amazon Resource Name or Key Alias + * @return True if well formed, false otherwise + */ + public static boolean isArnWellFormed(String arn) { + if (isKeyAlias(arn)) { + return true; + } + + try { + Arn.fromString(arn); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } + + private static boolean isKeyAlias(String arn) { + return arn.startsWith(ALIAS_PREFIX); + } + + private static boolean isArn(String arn) { + return arn.startsWith(ARN_PREFIX); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java new file mode 100644 index 000000000..6ed3629ff --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java @@ -0,0 +1,334 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.keyrings; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.ENCRYPTED_DATA_KEY; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.GENERATED_DATA_KEY; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KmsKeyringTest { + + private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final SecretKey PLAINTEXT_DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + private static final String GENERATOR_KEY_ID = "arn:aws:kms:us-east-1:999999999999:key/generator-89ab-cdef-fedc-ba9876543210"; + private static final String KEY_ID_1 = "arn:aws:kms:us-east-1:999999999999:key/key1-23bv-sdfs-werw-234323nfdsf"; + private static final String KEY_ID_2 = "arn:aws:kms:us-east-1:999999999999:key/key2-02ds-wvjs-aswe-a4923489273"; + private static final EncryptedDataKey ENCRYPTED_GENERATOR_KEY = new KeyBlob(KMS_PROVIDER_ID, + GENERATOR_KEY_ID.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey ENCRYPTED_KEY_1 = new KeyBlob(KMS_PROVIDER_ID, + KEY_ID_1.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey ENCRYPTED_KEY_2 = new KeyBlob(KMS_PROVIDER_ID, + KEY_ID_2.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final KeyringTraceEntry ENCRYPTED_GENERATOR_KEY_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry ENCRYPTED_KEY_1_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry ENCRYPTED_KEY_2_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry GENERATED_DATA_KEY_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, GENERATED_DATA_KEY); + @Mock(lenient = true) private DataKeyEncryptionDao dataKeyEncryptionDao; + private Keyring keyring; + + @BeforeEach + void setup() { + when(dataKeyEncryptionDao.encryptDataKey(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_GENERATOR_KEY); + when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_1, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_1); + when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_2, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_2); + + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_GENERATOR_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY)); + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(KEY_ID_1, PLAINTEXT_DATA_KEY)); + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(KEY_ID_2, PLAINTEXT_DATA_KEY)); + + List keyIds = new ArrayList<>(); + keyIds.add(KEY_ID_1); + keyIds.add(KEY_ID_2); + keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, GENERATOR_KEY_ID); + } + + @Test + void testMalformedArns() { + assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, null, "badArn")); + assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList("badArn"), GENERATOR_KEY_ID)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(new KeyBlob(KMS_PROVIDER_ID, "badArn".getBytes(PROVIDER_ENCODING), new byte[]{})); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + // Malformed Arn for a non KMS provider shouldn't fail + encryptedDataKeys.clear(); + encryptedDataKeys.add(new KeyBlob("OtherProviderId", "badArn".getBytes(PROVIDER_ENCODING), new byte[]{})); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + } + + @Test + void testGeneratorKeyInKeyIds() { + assertThrows(IllegalArgumentException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(GENERATOR_KEY_ID), GENERATOR_KEY_ID)); + } + + @Test + void testEncryptDecryptExistingDataKey() { + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .plaintextDataKey(PLAINTEXT_DATA_KEY) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + keyring.onEncrypt(encryptionMaterials); + + assertEquals(3, encryptionMaterials.getEncryptedDataKeys().size()); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_GENERATOR_KEY)); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1)); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_2)); + + assertEquals(3, encryptionMaterials.getKeyringTrace().getEntries().size()); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testEncryptNullDataKey() { + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .keyringTrace(new KeyringTrace()) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + when(dataKeyEncryptionDao.generateDataKey(GENERATOR_KEY_ID, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenReturn(new DataKeyEncryptionDao.GenerateDataKeyResult(PLAINTEXT_DATA_KEY, ENCRYPTED_GENERATOR_KEY)); + keyring.onEncrypt(encryptionMaterials); + + assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey()); + + assertEquals(4, encryptionMaterials.getKeyringTrace().getEntries().size()); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(GENERATED_DATA_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testEncryptNullGenerator() { + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .keyringTrace(new KeyringTrace()) + .plaintextDataKey(PLAINTEXT_DATA_KEY) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + Keyring keyring = new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(KEY_ID_1), null); + + keyring.onEncrypt(encryptionMaterials); + + assertEquals(1, encryptionMaterials.getEncryptedDataKeys().size()); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1)); + + assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey()); + + assertEquals(1, encryptionMaterials.getKeyringTrace().getEntries().size()); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE)); + } + + @Test + void testDiscoveryEncrypt() { + keyring = new KmsKeyring(dataKeyEncryptionDao, null, null); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + keyring.onEncrypt(encryptionMaterials); + + assertFalse(encryptionMaterials.hasPlaintextDataKey()); + assertEquals(0, encryptionMaterials.getKeyringTrace().getEntries().size()); + } + + @Test + void testEncryptNoGeneratorOrPlaintextDataKey() { + List keyIds = new ArrayList<>(); + keyIds.add(KEY_ID_1); + keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, null); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE).build(); + assertThrows(AwsCryptoException.class, () -> keyring.onEncrypt(encryptionMaterials)); + } + + @Test + void testDecryptFirstKeyFails() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new CannotUnwrapDataKeyException()); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDecryptMismatchedDataKeyException() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new MismatchedDataKeyException()); + + assertThrows(MismatchedDataKeyException.class, () -> keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_KEY_1))); + } + + @Test + void testDecryptFirstKeyWrongProvider() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + EncryptedDataKey wrongProviderKey = new KeyBlob("OtherProvider", KEY_ID_1.getBytes(PROVIDER_ENCODING), new byte[]{}); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(wrongProviderKey); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDiscoveryDecrypt() { + keyring = new KmsKeyring(dataKeyEncryptionDao, null, null); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDecryptAlreadyDecryptedDataKey() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .plaintextDataKey(PLAINTEXT_DATA_KEY) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_GENERATOR_KEY)); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size()); + } + + @Test + void testDecryptNoDataKey() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + keyring.onDecrypt(decryptionMaterials, Collections.emptyList()); + + assertFalse(decryptionMaterials.hasPlaintextDataKey()); + assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size()); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java new file mode 100644 index 000000000..cdca99fd6 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplierTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClientBuilder; +import com.amazonaws.services.kms.model.AWSKMSException; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.EncryptResult; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; +import com.amazonaws.services.kms.model.KMSInvalidStateException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KmsClientSupplierTest { + + @Mock AWSKMSClientBuilder kmsClientBuilder; + @Mock AWSKMS awskms; + @Mock EncryptRequest encryptRequest; + @Mock DecryptRequest decryptRequest; + @Mock GenerateDataKeyRequest generateDataKeyRequest; + @Mock AWSCredentialsProvider credentialsProvider; + @Mock ClientConfiguration clientConfiguration; + private static final String REGION_1 = "us-east-1"; + private static final String REGION_2 = "us-west-2"; + private static final String REGION_3 = "eu-west-1"; + + @Test + void testCredentialsAndClientConfiguration() { + when(kmsClientBuilder.withClientConfiguration(clientConfiguration)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.withCredentials(credentialsProvider)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + KmsClientSupplier supplier = new KmsClientSupplier.Builder(kmsClientBuilder) + .credentialsProvider(credentialsProvider) + .clientConfiguration(clientConfiguration) + .build(); + + supplier.getClient(null); + + verify(kmsClientBuilder).withClientConfiguration(clientConfiguration); + verify(kmsClientBuilder).withCredentials(credentialsProvider); + verify(kmsClientBuilder).build(); + } + + @Test + void testAllowedAndExcludedRegions() { + KmsClientSupplier supplierWithDefaultValues = new KmsClientSupplier.Builder(kmsClientBuilder) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithDefaultValues.getClient(REGION_1)); + + KmsClientSupplier supplierWithAllowed = new KmsClientSupplier.Builder(kmsClientBuilder) + .allowedRegions(Collections.singleton(REGION_1)) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertNotNull(supplierWithAllowed.getClient(REGION_1)); + assertThrows(UnsupportedRegionException.class, () -> supplierWithAllowed.getClient(REGION_2)); + + KmsClientSupplier supplierWithExcluded = new KmsClientSupplier.Builder(kmsClientBuilder) + .excludedRegions(Collections.singleton(REGION_1)) + .build(); + + when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + assertThrows(UnsupportedRegionException.class, () -> supplierWithExcluded.getClient(REGION_1)); + assertNotNull(supplierWithExcluded.getClient(REGION_2)); + + assertThrows(IllegalArgumentException.class, () -> new KmsClientSupplier.Builder(kmsClientBuilder) + .allowedRegions(Collections.singleton(REGION_1)) + .excludedRegions(Collections.singleton(REGION_2)) + .build()); + } + + @Test + void testClientCachingDisabled() { + KmsClientSupplier supplierCachingDisabled = new KmsClientSupplier.Builder(kmsClientBuilder) + .clientCaching(false) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + AWSKMS uncachedClient = supplierCachingDisabled.getClient(REGION_1); + verify(kmsClientBuilder, times(1)).build(); + + when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult()); + + uncachedClient.encrypt(encryptRequest); + supplierCachingDisabled.getClient(REGION_1); + verify(kmsClientBuilder, times(2)).build(); + } + + @Test + void testClientCaching() { + KmsClientSupplier supplier = new KmsClientSupplier.Builder(kmsClientBuilder) + .clientCaching(true) + .build(); + + when(kmsClientBuilder.withRegion(REGION_1)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.withRegion(REGION_2)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.withRegion(REGION_3)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(awskms); + + AWSKMS client = supplier.getClient(REGION_1); + AWSKMS client2 = supplier.getClient(REGION_2); + AWSKMS client3 = supplier.getClient(REGION_3); + verify(kmsClientBuilder, times(3)).build(); + + // No KMS methods have been called yet, so clients remain uncached + supplier.getClient(REGION_1); + supplier.getClient(REGION_2); + supplier.getClient(REGION_3); + verify(kmsClientBuilder, times(6)).build(); + + when(awskms.encrypt(encryptRequest)).thenReturn(new EncryptResult()); + when(awskms.decrypt(decryptRequest)).thenThrow(new KMSInvalidStateException("test")); + when(awskms.generateDataKey(generateDataKeyRequest)).thenThrow(new IllegalArgumentException("test")); + + // Successful KMS call, client is cached + client.encrypt(encryptRequest); + supplier.getClient(REGION_1); + verify(kmsClientBuilder, times(6)).build(); + + // KMS call resulted in KMS exception, client is cached + assertThrows(AWSKMSException.class, () -> client2.decrypt(decryptRequest)); + supplier.getClient(REGION_2); + verify(kmsClientBuilder, times(6)).build(); + + // KMS call resulted in non-KMS exception, client is not cached + assertThrows(IllegalArgumentException.class, () -> client3.generateDataKey(generateDataKeyRequest)); + supplier.getClient(REGION_3); + verify(kmsClientBuilder, times(7)).build(); + + // Non-KMS method, client is not cached + client3.toString(); + supplier.getClient(REGION_3); + verify(kmsClientBuilder, times(8)).build(); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java new file mode 100644 index 000000000..e9dc140d7 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.AmazonWebServiceRequest; +import com.amazonaws.RequestClientOptions; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.exception.UnsupportedRegionException; +import com.amazonaws.encryptionsdk.internal.VersionInfo; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.DecryptResult; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; +import com.amazonaws.services.kms.model.KMSInvalidStateException; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KmsDataKeyEncryptionDaoTest { + + private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final SecretKey DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + private static final EncryptedDataKey ENCRYPTED_DATA_KEY = new KeyBlob(KMS_PROVIDER_ID, + "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210".getBytes(EncryptedDataKey.PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + + @Test + void testEncryptAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + EncryptRequest actualRequest = er.getValue(); + + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); + assertUserAgent(actualRequest); + + assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId()); + assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation()); + assertNotNull(encryptedDataKeyResult.getEncryptedDataKey()); + + DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(encryptedDataKeyResult, ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals(encryptedDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(DATA_KEY, decryptDataKeyResult.getPlaintextDataKey()); + assertEquals(keyId, decryptDataKeyResult.getKeyArn()); + } + + @Test + void testGenerateAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + DataKeyEncryptionDao.GenerateDataKeyResult generateDataKeyResult = dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gr.capture()); + + GenerateDataKeyRequest actualRequest = gr.getValue(); + + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes()); + assertUserAgent(actualRequest); + + assertNotNull(generateDataKeyResult.getPlaintextDataKey()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getPlaintextDataKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getPlaintextDataKey().getAlgorithm()); + assertNotNull(generateDataKeyResult.getEncryptedDataKey()); + + DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(generateDataKeyResult.getEncryptedDataKey(), ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals(generateDataKeyResult.getEncryptedDataKey().getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(generateDataKeyResult.getPlaintextDataKey(), decryptDataKeyResult.getPlaintextDataKey()); + assertEquals(keyId, decryptDataKeyResult.getKeyArn()); + } + + @Test + void testEncryptWithRawKeyId() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + String rawKeyId = keyId.split("/")[1]; + EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(rawKeyId, DATA_KEY, ENCRYPTION_CONTEXT); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + EncryptRequest actualRequest = er.getValue(); + + assertEquals(rawKeyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); + assertUserAgent(actualRequest); + + assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId()); + assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation()); + assertNotNull(encryptedDataKeyResult.getEncryptedDataKey()); + } + + @Test + void testEncryptWrongKeyFormat() { + SecretKey key = mock(SecretKey.class); + when(key.getFormat()).thenReturn("BadFormat"); + + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + + assertThrows(IllegalArgumentException.class, () -> dao.encryptDataKey(keyId, key, ENCRYPTION_CONTEXT)); + } + + @Test + void testKmsFailure() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + doThrow(new KMSInvalidStateException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class)); + doThrow(new KMSInvalidStateException("fail")).when(client).encrypt(isA(EncryptRequest.class)); + doThrow(new KMSInvalidStateException("fail")).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + void testUnsupportedRegionException() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + doThrow(new UnsupportedRegionException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class)); + doThrow(new UnsupportedRegionException("fail")).when(client).encrypt(isA(EncryptRequest.class)); + doThrow(new UnsupportedRegionException("fail")).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + void testDecryptBadKmsKeyId() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId("badKeyId"); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(MismatchedDataKeyException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + void testDecryptBadKmsKeyLength() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId(new String(ENCRYPTED_DATA_KEY.getProviderInformation(), EncryptedDataKey.PROVIDER_ENCODING)); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(IllegalStateException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + private void assertUserAgent(AmazonWebServiceRequest request) { + assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.USER_AGENT)); + } + +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java new file mode 100644 index 000000000..e4c731ee7 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.DataKey; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.model.KeyBlob; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KmsMasterKeyTest { + + private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("test", "value"); + private static final String CMK_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210"; + @Mock KmsDataKeyEncryptionDao dataKeyEncryptionDao; + + /** + * Test that when decryption of an encrypted data key throws a MismatchedDataKeyException, this + * key is skipped and another key in the list of keys is decrypted. + */ + @Test + void testMismatchedDataKeyException() { + EncryptedDataKey encryptedDataKey1 = new KeyBlob(KMS_PROVIDER_ID, "KeyId1".getBytes(PROVIDER_ENCODING), generate(64)); + EncryptedDataKey encryptedDataKey2 = new KeyBlob(KMS_PROVIDER_ID, "KeyId2".getBytes(PROVIDER_ENCODING), generate(64)); + SecretKey secretKey = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + + when(dataKeyEncryptionDao.decryptDataKey(encryptedDataKey1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenThrow(new MismatchedDataKeyException()); + when(dataKeyEncryptionDao.decryptDataKey(encryptedDataKey2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DataKeyEncryptionDao.DecryptDataKeyResult("KeyId2", secretKey)); + + KmsMasterKey kmsMasterKey = new KmsMasterKey(dataKeyEncryptionDao, CMK_ARN, null); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(encryptedDataKey1); + encryptedDataKeys.add(encryptedDataKey2); + + DataKey result = kmsMasterKey.decryptDataKey(ALGORITHM_SUITE, encryptedDataKeys, ENCRYPTION_CONTEXT); + + assertEquals(secretKey, result.getKey()); + } + +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java new file mode 100644 index 000000000..e8bd05477 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.services.kms.AWSKMS; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.junit.jupiter.api.Assertions.*; + +@ExtendWith(MockitoExtension.class) +class KmsUtilsTest { + + private static final String VALID_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210"; + private static final String VALID_ALIAS_ARN = "arn:aws:kms:us-east-1:999999999999:alias/MyCryptoKey"; + private static final String VALID_ALIAS = "alias/MyCryptoKey"; + private static final String VALID_RAW_KEY_ID = "01234567-89ab-cdef-fedc-ba9876543210"; + + @Mock + private AWSKMS client; + + + @Test + void testGetClientByArn() { + assertEquals(client, KmsUtils.getClientByArn(VALID_ARN, s -> client)); + assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS_ARN, s -> client)); + assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS, s -> client)); + assertThrows(MalformedArnException.class, () -> KmsUtils.getClientByArn("arn:invalid", s -> client)); + assertEquals(client, KmsUtils.getClientByArn(VALID_RAW_KEY_ID, s -> client)); + } + + @Test + void testIsArnWellFormed() { + assertTrue(KmsUtils.isArnWellFormed(VALID_ARN)); + assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS_ARN)); + assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS)); + assertFalse(KmsUtils.isArnWellFormed(VALID_RAW_KEY_ID)); + assertFalse(KmsUtils.isArnWellFormed("arn:invalid")); + + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java index 37fe9cbff..00ce5c074 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java @@ -29,7 +29,7 @@ import com.amazonaws.ResponseMetadata; import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.AWSKMSClient; +import com.amazonaws.services.kms.AbstractAWSKMS; import com.amazonaws.services.kms.model.CreateAliasRequest; import com.amazonaws.services.kms.model.CreateAliasResult; import com.amazonaws.services.kms.model.CreateGrantRequest; @@ -85,7 +85,7 @@ import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest; import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult; -public class MockKMSClient extends AWSKMSClient { +public class MockKMSClient extends AbstractAWSKMS { private static final SecureRandom rnd = new SecureRandom(); private static final String ACCOUNT_ID = "01234567890"; private final Map results_ = new HashMap<>();