diff --git a/examples/src/keyrings/hierarchical_keyring.py b/examples/src/keyrings/hierarchical_keyring.py index aa87485f9..32a6cbf8b 100644 --- a/examples/src/keyrings/hierarchical_keyring.py +++ b/examples/src/keyrings/hierarchical_keyring.py @@ -1,6 +1,36 @@ # Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Example showing basic encryption and decryption of a value already in memory.""" +""" +This example sets up the Hierarchical Keyring, which establishes a key hierarchy where "branch" +keys are persisted in DynamoDb. These branch keys are used to protect your data keys, and these +branch keys are themselves protected by a KMS Key. + +Establishing a key hierarchy like this has two benefits: +First, by caching the branch key material, and only calling KMS to re-establish authentication +regularly according to your configured TTL, you limit how often you need to call KMS to protect +your data. This is a performance security tradeoff, where your authentication, audit, and logging +from KMS is no longer one-to-one with every encrypt or decrypt call. Additionally, KMS Cloudtrail +cannot be used to distinguish Encrypt and Decrypt calls, and you cannot restrict who has +Encryption rights from who has Decryption rights since they both ONLY need KMS:Decrypt. However, +the benefit is that you no longer have to make a network call to KMS for every encrypt or +decrypt. + +Second, this key hierarchy facilitates cryptographic isolation of a tenant's data in a +multi-tenant data store. Each tenant can have a unique Branch Key, that is only used to protect +the tenant's data. You can either statically configure a single branch key to ensure you are +restricting access to a single tenant, or you can implement an interface that selects the Branch +Key based on the Encryption Context. 
+ +This example demonstrates configuring a Hierarchical Keyring with a Branch Key ID Supplier to +encrypt and decrypt data for two separate tenants. + +This example requires access to the DDB Table where you are storing the Branch Keys. This +table must be configured with the following primary key configuration: - Partition key is named +"partition_key" with type (S) - Sort key is named "sort_key" with type (S). + +This example also requires using a KMS Key. You need the following access on this key: - +GenerateDataKeyWithoutPlaintext - Decrypt +""" import sys import boto3 @@ -25,6 +55,7 @@ from .example_branch_key_id_supplier import ExampleBranchKeyIdSupplier +# TODO-MPL: Remove this as part of removing PYTHONPATH hacks. module_root_dir = '/'.join(__file__.split("/")[:-1]) sys.path.append(module_root_dir) diff --git a/examples/src/keyrings/required_encryption_context_cmm.py b/examples/src/keyrings/required_encryption_context_cmm.py new file mode 100644 index 000000000..e0c19697c --- /dev/null +++ b/examples/src/keyrings/required_encryption_context_cmm.py @@ -0,0 +1,158 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Demonstrate an encrypt/decrypt cycle using a Required Encryption Context CMM. +A required encryption context CMM asks for required keys in the encryption context field +on encrypt such that they will not be stored on the message, but WILL be included in the header signature. +On decrypt, the client MUST supply the key/value pair(s) that were not stored to successfully decrypt the message. 
+""" +import sys + +import boto3 +# Ignore missing MPL for pylint, but the MPL is required for this example +# noqa pylint: disable=import-error +from aws_cryptographic_materialproviders.mpl import AwsCryptographicMaterialProviders +from aws_cryptographic_materialproviders.mpl.config import MaterialProvidersConfig +from aws_cryptographic_materialproviders.mpl.models import ( + CreateAwsKmsKeyringInput, + CreateDefaultCryptographicMaterialsManagerInput, + CreateRequiredEncryptionContextCMMInput, +) +from aws_cryptographic_materialproviders.mpl.references import ICryptographicMaterialsManager, IKeyring +from typing import Dict, List + +import aws_encryption_sdk +from aws_encryption_sdk import CommitmentPolicy +from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError + +# TODO-MPL: Remove this as part of removing PYTHONPATH hacks +module_root_dir = '/'.join(__file__.split("/")[:-1]) + +sys.path.append(module_root_dir) + +EXAMPLE_DATA: bytes = b"Hello World" + + +def encrypt_and_decrypt_with_keyring( + kms_key_id: str +): + """Creates a required encryption context CMM, then encrypts and decrypts a string with it.""" + # 1. Instantiate the encryption SDK client. + # This builds the client with the REQUIRE_ENCRYPT_REQUIRE_DECRYPT commitment policy, + # which enforces that this client only encrypts using committing algorithm suites and enforces + # that this client will only decrypt encrypted messages that were created with a committing + # algorithm suite. + # This is the default commitment policy if you were to build the client as + # `client = aws_encryption_sdk.EncryptionSDKClient()`. + + client = aws_encryption_sdk.EncryptionSDKClient( + commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT + ) + + # 2. Create an encryption context. + # Most encrypted data should have an associated encryption context + # to protect integrity. This sample uses placeholder values.
+ # For more information see: + # blogs.aws.amazon.com/security/post/Tx2LZ6WBJJANTNW/How-to-Protect-the-Integrity-of-Your-Encrypted-Data-by-Using-AWS-Key-Management # noqa: E501 + encryption_context: Dict[str, str] = { + "key1": "value1", + "key2": "value2", + "requiredKey1": "requiredValue1", + "requiredKey2": "requiredValue2", + } + + # 3. Create list of required encryption context keys. + # This is a list of keys that must be present in the encryption context. + required_encryption_context_keys: List[str] = ["requiredKey1", "requiredKey2"] + + # 4. Create the AWS KMS keyring. + mpl: AwsCryptographicMaterialProviders = AwsCryptographicMaterialProviders( + config=MaterialProvidersConfig() + ) + keyring_input: CreateAwsKmsKeyringInput = CreateAwsKmsKeyringInput( + kms_key_id=kms_key_id, + kms_client=boto3.client('kms', region_name="us-west-2") + ) + kms_keyring: IKeyring = mpl.create_aws_kms_keyring(keyring_input) + + # 5. Create the required encryption context CMM. + underlying_cmm: ICryptographicMaterialsManager = \ + mpl.create_default_cryptographic_materials_manager( + CreateDefaultCryptographicMaterialsManagerInput( + keyring=kms_keyring + ) + ) + + required_ec_cmm: ICryptographicMaterialsManager = \ + mpl.create_required_encryption_context_cmm( + CreateRequiredEncryptionContextCMMInput( + required_encryption_context_keys=required_encryption_context_keys, + underlying_cmm=underlying_cmm, + ) + ) + + # 6. Encrypt the data + ciphertext, _ = client.encrypt( + source=EXAMPLE_DATA, + materials_manager=required_ec_cmm, + encryption_context=encryption_context + ) + + # 7. Reproduce the encryption context. + # The reproduced encryption context MUST contain a value for + # every key in the configured required encryption context keys during encryption with + # Required Encryption Context CMM. + reproduced_encryption_context: Dict[str, str] = { + "requiredKey1": "requiredValue1", + "requiredKey2": "requiredValue2", + } + + # 8. 
Decrypt the data + plaintext_bytes_A, _ = client.decrypt( + source=ciphertext, + materials_manager=required_ec_cmm, + encryption_context=reproduced_encryption_context + ) + assert plaintext_bytes_A == EXAMPLE_DATA + + # We can also decrypt using the underlying CMM, + # but must also provide the reproduced encryption context + plaintext_bytes_A, _ = client.decrypt( + source=ciphertext, + materials_manager=underlying_cmm, + encryption_context=reproduced_encryption_context + ) + assert plaintext_bytes_A == EXAMPLE_DATA + + # 9. Extra: Demonstrate that if we don't provide the reproduced encryption context, + # decryption will fail. + try: + plaintext_bytes_A, _ = client.decrypt( + source=ciphertext, + materials_manager=required_ec_cmm, + # No reproduced encryption context for required EC CMM-produced message makes decryption fail. + ) + raise Exception("If this exception is raised, decryption somehow succeeded!") + except AWSEncryptionSDKClientError: + # Swallow specific expected exception. + # We expect decryption to fail with an AWSEncryptionSDKClientError + # since we did not provide reproduced encryption context when decrypting + # a message encrypted with the required encryption context CMM. + pass + + # Same for the default CMM; + # If we don't provide the reproduced encryption context, decryption will fail. + try: + plaintext_bytes_A, _ = client.decrypt( + source=ciphertext, + materials_manager=required_ec_cmm, + # No reproduced encryption context for required EC CMM-produced message makes decryption fail. + ) + raise Exception("If this exception is raised, decryption somehow succeeded!") + except AWSEncryptionSDKClientError: + # Swallow specific expected exception. + # We expect decryption to fail with an AWSEncryptionSDKClientError + # since we did not provide reproduced encryption context when decrypting + # a message encrypted with the required encryption context CMM, + # even though we are using a default CMM on decrypt.
+ pass diff --git a/examples/test/keyrings/test_i_hierarchical_keyring.py b/examples/test/keyrings/test_i_hierarchical_keyring.py index 4cae478d7..c4583534a 100644 --- a/examples/test/keyrings/test_i_hierarchical_keyring.py +++ b/examples/test/keyrings/test_i_hierarchical_keyring.py @@ -1,6 +1,6 @@ # Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit test suite for the hierarchical keyring example.""" +"""Test suite for the hierarchical keyring example.""" import pytest from ...src.keyrings.hierarchical_keyring import encrypt_and_decrypt_with_keyring diff --git a/examples/test/keyrings/test_i_required_encryption_context_cmm.py b/examples/test/keyrings/test_i_required_encryption_context_cmm.py new file mode 100644 index 000000000..724705faa --- /dev/null +++ b/examples/test/keyrings/test_i_required_encryption_context_cmm.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Test suite for the required encryption context CMM example.""" +import pytest + +from ...src.keyrings.required_encryption_context_cmm import encrypt_and_decrypt_with_keyring + +pytestmark = [pytest.mark.examples] + + +def test_encrypt_and_decrypt_with_keyring(): + key_arn = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f" + encrypt_and_decrypt_with_keyring(key_arn) diff --git a/setup.py b/setup.py index 0615a43c7..697cb96ce 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def get_requirements(): license="Apache License 2.0", install_requires=get_requirements(), # pylint: disable=fixme - # TODO: Point at PyPI once MPL is released. + # TODO-MPL: Point at PyPI once MPL is released. # This blocks releasing ESDK-Python MPL integration. 
extras_require={ "MPL": ["aws-cryptographic-material-providers @" \ diff --git a/src/aws_encryption_sdk/__init__.py b/src/aws_encryption_sdk/__init__.py index 661d41ee6..74732bfc5 100644 --- a/src/aws_encryption_sdk/__init__.py +++ b/src/aws_encryption_sdk/__init__.py @@ -185,6 +185,9 @@ def decrypt(self, **kwargs): If source_length is not provided and read() is called, will attempt to seek() to the end of the stream and tell() to find the length of source data. + :param dict encryption_context: Dictionary defining encryption context to validate + on decrypt. This is ONLY validated on decrypt if using a CMM from the + aws-cryptographic-material-providers library. :param int max_body_length: Maximum frame size (or content length for non-framed messages) in bytes to read from ciphertext message. :returns: Tuple containing the decrypted plaintext and the message header object diff --git a/src/aws_encryption_sdk/internal/formatting/serialize.py b/src/aws_encryption_sdk/internal/formatting/serialize.py index b4d866099..9f1325f98 100644 --- a/src/aws_encryption_sdk/internal/formatting/serialize.py +++ b/src/aws_encryption_sdk/internal/formatting/serialize.py @@ -189,7 +189,13 @@ def serialize_header(header, signer=None): raise SerializationError("Unrecognized message format version: {}".format(header.version)) -def _serialize_header_auth_v1(algorithm, header, data_encryption_key, signer=None): +def _serialize_header_auth_v1( + algorithm, + header, + data_encryption_key, + signer=None, + required_ec_bytes=None +): """Creates serialized header authentication data for messages in serialization version V1. 
:param algorithm: Algorithm to use for encryption @@ -198,16 +204,35 @@ def _serialize_header_auth_v1(algorithm, header, data_encryption_key, signer=Non :param bytes data_encryption_key: Data key with which to encrypt message :param signer: Cryptographic signer object (optional) :type signer: aws_encryption_sdk.Signer + :param required_ec_bytes: Serialized encryption context items + for all items whose keys are in the required_encryption_context list. + This is ONLY processed if using the aws-cryptographic-materialproviders library + AND its required encryption context CMM. (optional) + :type required_ec_bytes: bytes :returns: Serialized header authentication data :rtype: bytes """ - header_auth = encrypt( - algorithm=algorithm, - key=data_encryption_key, - plaintext=b"", - associated_data=header, - iv=header_auth_iv(algorithm), - ) + if required_ec_bytes is None: + header_auth = encrypt( + algorithm=algorithm, + key=data_encryption_key, + plaintext=b"", + associated_data=header, + iv=header_auth_iv(algorithm), + ) + else: + header_auth = encrypt( + algorithm=algorithm, + key=data_encryption_key, + plaintext=b"", + # The AAD MUST be the concatenation of the serialized message header body and the serialization + # of encryption context to only authenticate. The encryption context to only authenticate MUST + # be the encryption context in the encryption materials filtered to only contain key value + # pairs listed in the encryption material's required encryption context keys serialized + # according to the encryption context serialization specification.
+ associated_data=header + required_ec_bytes, + iv=header_auth_iv(algorithm), + ) output = struct.pack( ">{iv_len}s{tag_len}s".format(iv_len=algorithm.iv_len, tag_len=algorithm.tag_len), header_auth.iv, @@ -218,7 +243,13 @@ def _serialize_header_auth_v1(algorithm, header, data_encryption_key, signer=Non return output -def _serialize_header_auth_v2(algorithm, header, data_encryption_key, signer=None): +def _serialize_header_auth_v2( + algorithm, + header, + data_encryption_key, + signer=None, + required_ec_bytes=None +): """Creates serialized header authentication data for messages in serialization version V2. :param algorithm: Algorithm to use for encryption @@ -227,16 +258,35 @@ def _serialize_header_auth_v2(algorithm, header, data_encryption_key, signer=Non :param bytes data_encryption_key: Data key with which to encrypt message :param signer: Cryptographic signer object (optional) :type signer: aws_encryption_sdk.Signer + :param required_ec_bytes: Serialized encryption context items + for all items whose keys are in the required_encryption_context list. + This is ONLY processed if using the aws-cryptographic-materialproviders library + AND its required encryption context CMM. (optional) + :type required_ec_bytes: bytes :returns: Serialized header authentication data :rtype: bytes """ - header_auth = encrypt( - algorithm=algorithm, - key=data_encryption_key, - plaintext=b"", - associated_data=header, - iv=header_auth_iv(algorithm), - ) + if required_ec_bytes is None: + header_auth = encrypt( + algorithm=algorithm, + key=data_encryption_key, + plaintext=b"", + associated_data=header, + iv=header_auth_iv(algorithm), + ) + else: + header_auth = encrypt( + algorithm=algorithm, + key=data_encryption_key, + plaintext=b"", + # The AAD MUST be the concatenation of the serialized message header body and the serialization + # of encryption context to only authenticate.
The encryption context to only authenticate MUST + # be the encryption context in the encryption materials filtered to only contain key value + # pairs listed in the encryption material's required encryption context keys serialized + # according to the encryption context serialization specification. + associated_data=header + required_ec_bytes, + iv=header_auth_iv(algorithm), + ) output = struct.pack( ">{tag_len}s".format(tag_len=algorithm.tag_len), header_auth.tag, @@ -246,7 +296,14 @@ def _serialize_header_auth_v2(algorithm, header, data_encryption_key, signer=Non return output -def serialize_header_auth(version, algorithm, header, data_encryption_key, signer=None): +def serialize_header_auth( + version, + algorithm, + header, + data_encryption_key, + signer=None, + required_ec_bytes=None +): """Creates serialized header authentication data. :param version: The serialization version of the message @@ -257,13 +314,22 @@ def serialize_header_auth(version, algorithm, header, data_encryption_key, signe :param bytes data_encryption_key: Data key with which to encrypt message :param signer: Cryptographic signer object (optional) :type signer: aws_encryption_sdk.Signer + :param required_encryption_context_bytes: Serialized encryption context items + for all items whose keys are in the required_encryption_context list. + This is ONLY processed if using the aws-cryptographic-materialproviders library + AND its required encryption context CMM. 
(optional) + :type required_encryption_context_bytes: bytes :returns: Serialized header authentication data :rtype: bytes """ if version == SerializationVersion.V1: - return _serialize_header_auth_v1(algorithm, header, data_encryption_key, signer) + return _serialize_header_auth_v1( + algorithm, header, data_encryption_key, signer, required_ec_bytes + ) elif version == SerializationVersion.V2: - return _serialize_header_auth_v2(algorithm, header, data_encryption_key, signer) + return _serialize_header_auth_v2( + algorithm, header, data_encryption_key, signer, required_ec_bytes + ) else: raise SerializationError("Unrecognized message format version: {}".format(version)) diff --git a/src/aws_encryption_sdk/materials_managers/__init__.py b/src/aws_encryption_sdk/materials_managers/__init__.py index 9db1dafae..950dd87cd 100644 --- a/src/aws_encryption_sdk/materials_managers/__init__.py +++ b/src/aws_encryption_sdk/materials_managers/__init__.py @@ -89,11 +89,17 @@ class DecryptionMaterialsRequest(object): :param encrypted_data_keys: Set of encrypted data keys :type encrypted_data_keys: set of `aws_encryption_sdk.structures.EncryptedDataKey` :param dict encryption_context: Encryption context to provide to master keys for underlying decrypt requests + :param dict reproduced_encryption_context: Encryption context to provide on decrypt. + This is ONLY processed if using a CMM from the aws-cryptographic-materialproviders library. 
""" algorithm = attr.ib(validator=attr.validators.instance_of(Algorithm)) encrypted_data_keys = attr.ib(validator=attr.validators.instance_of(set)) encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) + reproduced_encryption_context = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(dict)) + ) commitment_policy = attr.ib( default=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, validator=attr.validators.optional(attr.validators.instance_of(CommitmentPolicy)), diff --git a/src/aws_encryption_sdk/materials_managers/mpl/cmm.py b/src/aws_encryption_sdk/materials_managers/mpl/cmm.py index 880e37203..1f2102757 100644 --- a/src/aws_encryption_sdk/materials_managers/mpl/cmm.py +++ b/src/aws_encryption_sdk/materials_managers/mpl/cmm.py @@ -143,5 +143,6 @@ def _create_mpl_decrypt_materials_input_from_request( ), encrypted_data_keys=list_edks, encryption_context=request.encryption_context, + reproduced_encryption_context=request.reproduced_encryption_context, ) return output diff --git a/src/aws_encryption_sdk/materials_managers/mpl/materials.py b/src/aws_encryption_sdk/materials_managers/mpl/materials.py index dfd1bd6fc..54ea21b39 100644 --- a/src/aws_encryption_sdk/materials_managers/mpl/materials.py +++ b/src/aws_encryption_sdk/materials_managers/mpl/materials.py @@ -96,6 +96,13 @@ def signing_key(self) -> bytes: """Materials' signing key.""" return self.mpl_materials.signing_key + @property + # Pylint thinks this name is too long, but it's the best descriptor for this... 
+ # pylint: disable=invalid-name + def required_encryption_context_keys(self) -> bytes: + """Materials' required encryption context keys.""" + return self.mpl_materials.required_encryption_context_keys + class DecryptionMaterialsFromMPL(Native_DecryptionMaterials): """ @@ -136,3 +143,15 @@ def data_key(self) -> DataKey: def verification_key(self) -> bytes: """Materials' verification key.""" return self.mpl_materials.verification_key + + @property + def encryption_context(self) -> Dict[str, str]: + """Materials' encryption context.""" + return self.mpl_materials.encryption_context + + @property + # Pylint thinks this name is too long, but it's the best descriptor for this... + # pylint: disable=invalid-name + def required_encryption_context_keys(self) -> bytes: + """Materials' required encryption context keys.""" + return self.mpl_materials.required_encryption_context_keys diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 5bf953244..a86112610 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -77,7 +77,10 @@ from aws_cryptographic_materialproviders.mpl.config import MaterialProvidersConfig from aws_cryptographic_materialproviders.mpl.errors import AwsCryptographicMaterialProvidersException from aws_cryptographic_materialproviders.mpl.models import CreateDefaultCryptographicMaterialsManagerInput - from aws_cryptographic_materialproviders.mpl.references import IKeyring + from aws_cryptographic_materialproviders.mpl.references import ( + ICryptographicMaterialsManager as MPL_ICryptographicMaterialsManager, + IKeyring as MPL_IKeyring, + ) _HAS_MPL = True # Import internal ESDK modules that depend on the MPL @@ -126,16 +129,37 @@ class _ClientConfig(object): # pylint: disable=too-many-instance-attributes max_encrypted_data_keys = attr.ib( hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(int)) ) - materials_manager = 
attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(CryptoMaterialsManager)) - ) + if _HAS_MPL: + # With the MPL, the provided materials_manager can be an instance of + # either the native interface or an MPL interface. + # If it implements the MPL interface, this constructor will + # internally wrap it in a native interface. + materials_manager = attr.ib( + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of( + (CryptoMaterialsManager, MPL_ICryptographicMaterialsManager) + ) + ) + ) + else: + materials_manager = attr.ib( + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of( + CryptoMaterialsManager + ) + ) + ) key_provider = attr.ib( hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(MasterKeyProvider)) ) if _HAS_MPL: # Keyrings are only available if the MPL is installed in the runtime keyring = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(IKeyring)) + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(MPL_IKeyring)) ) source_length = attr.ib( hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) @@ -173,6 +197,12 @@ def _has_mpl_attrs_post_init(self): # so customers only have to catch ESDK error types. raise AWSEncryptionSDKClientError(mpl_exception) + # If the provided materials_manager is directly from the MPL, wrap it in a native interface + # for internal use. + elif (self.materials_manager is not None + and isinstance(self.materials_manager, MPL_ICryptographicMaterialsManager)): + self.materials_manager = CryptoMaterialsManagerFromMPL(self.materials_manager) + def _no_mpl_attrs_post_init(self): """If the MPL is NOT present in the runtime, perform post-init logic to validate the new object has a valid state. 
@@ -593,11 +623,27 @@ def generate_header(self, message_id): if self._encryption_materials.algorithm.message_format_version == 0x02: version = SerializationVersion.V2 + # If the underlying materials_provider provided required_encryption_context_keys + # (ex. if the materials_provider is a required encryption context CMM), + # then partition the encryption context based on those keys. + if hasattr(self._encryption_materials, "required_encryption_context_keys"): + self._required_encryption_context = {} + self._stored_encryption_context = {} + for (key, value) in self._encryption_materials.encryption_context.items(): + if key in self._encryption_materials.required_encryption_context_keys: + self._required_encryption_context[key] = value + else: + self._stored_encryption_context[key] = value + # Otherwise, store all encryption context with the message. + else: + self._stored_encryption_context = self._encryption_materials.encryption_context + self._required_encryption_context = None + kwargs = dict( version=version, algorithm=self._encryption_materials.algorithm, message_id=message_id, - encryption_context=self._encryption_materials.encryption_context, + encryption_context=self._stored_encryption_context, encrypted_data_keys=self._encryption_materials.encrypted_data_keys, content_type=self.content_type, frame_length=self.config.frame_length, @@ -621,13 +667,31 @@ def generate_header(self, message_id): def _write_header(self): """Builds the message header and writes it to the output stream.""" self.output_buffer += serialize_header(header=self._header, signer=self.signer) - self.output_buffer += serialize_header_auth( - version=self._header.version, - algorithm=self._encryption_materials.algorithm, - header=self.output_buffer, - data_encryption_key=self._derived_data_key, - signer=self.signer, - ) + + # If there is _required_encryption_context, + # serialize it, then authenticate it + if hasattr(self, "_required_encryption_context"): + required_ec_serialized = \ + 
aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context( + self._required_encryption_context + ) + self.output_buffer += serialize_header_auth( + version=self._header.version, + algorithm=self._encryption_materials.algorithm, + header=self.output_buffer, + data_encryption_key=self._derived_data_key, + signer=self.signer, + required_ec_bytes=required_ec_serialized, + ) + # Otherwise, do not pass in any required encryption context + else: + self.output_buffer += serialize_header_auth( + version=self._header.version, + algorithm=self._encryption_materials.algorithm, + header=self.output_buffer, + data_encryption_key=self._derived_data_key, + signer=self.signer, + ) def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" @@ -825,11 +889,19 @@ class DecryptorConfig(_ClientConfig): :param int max_body_length: Maximum frame size (or content length for non-framed messages) in bytes to read from ciphertext message. + :param dict encryption_context: Dictionary defining encryption context to validate + on decrypt. This is ONLY validated on decrypt if using a CMM + from the aws-cryptographic-material-providers library. """ max_body_length = attr.ib( hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) ) + encryption_context = attr.ib( + hash=False, # dictionaries are not hashable + default=None, + validator=attr.validators.optional(attr.validators.instance_of(dict)), + ) class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-attributes @@ -882,6 +954,77 @@ def _prep_message(self): self._prep_non_framed() self._message_prepped = True + def _create_decrypt_materials_request(self, header): + """ + Create a DecryptionMaterialsRequest based on whether + the StreamDecryptor was provided encryption_context on decrypt + (i.e. expects to use a CMM from the MPL). 
+ """ + # If encryption_context is provided on decrypt, + # pass it to the DecryptionMaterialsRequest as reproduced_encryption_context + if hasattr(self.config, "encryption_context") \ + and self.config.encryption_context is not None: + if (_HAS_MPL + and isinstance(self.config.materials_manager, CryptoMaterialsManagerFromMPL)): + return DecryptionMaterialsRequest( + encrypted_data_keys=header.encrypted_data_keys, + algorithm=header.algorithm, + encryption_context=header.encryption_context, + commitment_policy=self.config.commitment_policy, + reproduced_encryption_context=self.config.encryption_context + ) + else: + raise TypeError("encryption_context on decrypt is only supported for CMMs and keyrings " + "from the aws-cryptographic-material-providers library.") + return DecryptionMaterialsRequest( + encrypted_data_keys=header.encrypted_data_keys, + algorithm=header.algorithm, + encryption_context=header.encryption_context, + commitment_policy=self.config.commitment_policy, + ) + + def _validate_parsed_header( + self, + header, + header_auth, + raw_header, + ): + """ + Pass arguments from this StreamDecryptor to validate_header based on whether + the StreamDecryptor has the _required_encryption_context attribute + (i.e. is using the required encryption context CMM from the MPL). + """ + # If _required_encryption_context is present, + # serialize it and pass it to validate_header. + if hasattr(self, "_required_encryption_context") \ + and self._required_encryption_context is not None: + # The authenticated only encryption context is all encryption context key-value pairs where the + # key exists in Required Encryption Context Keys. It is then serialized according to the + # message header Key Value Pairs. 
+ required_ec_serialized = \ + aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context( + self._required_encryption_context + ) + + validate_header( + header=header, + header_auth=header_auth, + # When verifying the header, the AAD input to the authenticated encryption algorithm + # specified by the algorithm suite is the message header body and the serialized + # authenticated only encryption context. + raw_header=raw_header + required_ec_serialized, + data_key=self._derived_data_key + ) + else: + validate_header( + header=header, + header_auth=header_auth, + raw_header=raw_header, + data_key=self._derived_data_key + ) + + return header, header_auth + def _read_header(self): """Reads the message header from the input stream. @@ -908,13 +1051,20 @@ def _read_header(self): ) ) - decrypt_materials_request = DecryptionMaterialsRequest( - encrypted_data_keys=header.encrypted_data_keys, - algorithm=header.algorithm, - encryption_context=header.encryption_context, - commitment_policy=self.config.commitment_policy, - ) + decrypt_materials_request = self._create_decrypt_materials_request(header) decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request) + + # If the materials_manager passed required_encryption_context_keys, + # get the items out of the encryption_context with the keys. + # The items are used in header validation. + if hasattr(decryption_materials, "required_encryption_context_keys"): + self._required_encryption_context = {} + for (key, value) in decryption_materials.encryption_context.items(): + if key in decryption_materials.required_encryption_context_keys: + self._required_encryption_context[key] = value + else: + self._required_encryption_context = None + if decryption_materials.verification_key is None: self.verifier = None else: @@ -953,9 +1103,11 @@ def _read_header(self): "message. Halting processing of this message." 
) - validate_header(header=header, header_auth=header_auth, raw_header=raw_header, data_key=self._derived_data_key) - - return header, header_auth + return self._validate_parsed_header( + header=header, + header_auth=header_auth, + raw_header=raw_header, + ) def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" diff --git a/test/mpl/unit/test_material_managers_mpl_cmm.py b/test/mpl/unit/test_material_managers_mpl_cmm.py index 80d6f00ee..0551e8f30 100644 --- a/test/mpl/unit/test_material_managers_mpl_cmm.py +++ b/test/mpl/unit/test_material_managers_mpl_cmm.py @@ -38,6 +38,7 @@ mock_mpl_cmm = MagicMock(__class__=MPL_ICryptographicMaterialsManager) mock_mpl_encryption_materials = MagicMock(__class__=MPL_EncryptionMaterials) mock_mpl_decrypt_materials = MagicMock(__class__=MPL_DecryptionMaterials) +mock_reproduced_encryption_context = MagicMock(__class_=dict) mock_edk = MagicMock(__class__=Native_EncryptedDataKey) @@ -259,6 +260,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques for mock_edks in [no_mock_edks, one_mock_edk, two_mock_edks]: mock_decryption_materials_request.encrypted_data_keys = mock_edks + mock_decryption_materials_request.reproduced_encryption_context = mock_reproduced_encryption_context # When: _create_mpl_decrypt_materials_input_from_request output = CryptoMaterialsManagerFromMPL._create_mpl_decrypt_materials_input_from_request( @@ -271,6 +273,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques assert output.algorithm_suite_id == mock_algorithm_id assert output.commitment_policy == mock_commitment_policy assert output.encryption_context == mock_decryption_materials_request.encryption_context + assert output.reproduced_encryption_context == mock_reproduced_encryption_context assert len(output.encrypted_data_keys) == len(mock_edks) for i in range(len(output.encrypted_data_keys)): diff --git a/test/mpl/unit/test_material_managers_mpl_materials.py 
b/test/mpl/unit/test_material_managers_mpl_materials.py index 9e76556a2..8d9052c0a 100644 --- a/test/mpl/unit/test_material_managers_mpl_materials.py +++ b/test/mpl/unit/test_material_managers_mpl_materials.py @@ -160,6 +160,19 @@ def test_GIVEN_valid_signing_key_WHEN_EncryptionMaterials_get_signing_key_THEN_r assert output == mock_signing_key +def test_GIVEN_valid_required_encryption_context_keys_WHEN_EncryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys(): # noqa pylint: disable=line-too-long + # Given: valid required encryption context keys + mock_required_encryption_context_keys = MagicMock(__class__=bytes) + mock_mpl_encryption_materials.required_encryption_context_keys = mock_required_encryption_context_keys + + # When: get required encryption context keys + mpl_encryption_materials = EncryptionMaterialsFromMPL(mpl_materials=mock_mpl_encryption_materials) + output = mpl_encryption_materials.required_encryption_context_keys + + # Then: returns required encryption context keys + assert output == mock_required_encryption_context_keys + + def test_GIVEN_valid_data_key_WHEN_DecryptionMaterials_get_data_key_THEN_returns_data_key(): # Given: valid MPL data key mock_data_key = MagicMock(__class__=bytes) @@ -187,3 +200,29 @@ def test_GIVEN_valid_verification_key_WHEN_DecryptionMaterials_get_verification_ # Then: returns verification key assert output == mock_verification_key + + +def test_GIVEN_valid_encryption_context_WHEN_DecryptionMaterials_get_encryption_context_THEN_returns_encryption_context(): # noqa pylint: disable=line-too-long + # Given: valid encryption context + mock_encryption_context = MagicMock(__class__=Dict[str, str]) + mock_mpl_decrypt_materials.encryption_context = mock_encryption_context + + # When: get encryption context + mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials) + output = mpl_decryption_materials.encryption_context + + # Then: returns valid 
encryption context + assert output == mock_encryption_context + + +def test_GIVEN_valid_required_encryption_context_keys_WHEN_DecryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys(): # noqa pylint: disable=line-too-long + # Given: valid required encryption context keys + mock_required_encryption_context_keys = MagicMock(__class__=bytes) + mock_mpl_decrypt_materials.required_encryption_context_keys = mock_required_encryption_context_keys + + # When: get required encryption context keys + mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials) + output = mpl_decryption_materials.required_encryption_context_keys + + # Then: returns required encryption context keys + assert output == mock_required_encryption_context_keys diff --git a/test/unit/test_serialize.py b/test/unit/test_serialize.py index 56da114b4..06ac6126b 100644 --- a/test/unit/test_serialize.py +++ b/test/unit/test_serialize.py @@ -79,6 +79,7 @@ def apply_fixtures(self): "aws_encryption_sdk.internal.formatting.serialize.aws_encryption_sdk.internal.utils.validate_frame_length" ) self.mock_valid_frame_length = self.mock_valid_frame_length_patcher.start() + self.mock_required_ec_bytes = MagicMock() # Set up mock signer self.mock_signer = MagicMock() self.mock_signer.update.return_value = None @@ -167,6 +168,34 @@ def test_serialize_header_auth_v1_no_signer(self): data_encryption_key=VALUES["data_key_obj"], ) + @patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv") + def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v1_THEN_aad_has_required_ec_bytes( + self, + mock_header_auth_iv, + ): + """Validate that the _create_header_auth function + behaves as expected for SerializationVersion.V1 + when required_ec_bytes are provided. 
+ """ + self.mock_encrypt.return_value = VALUES["header_auth_base"] + test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth( + version=SerializationVersion.V1, + algorithm=self.mock_algorithm, + header=VALUES["serialized_header"], + data_encryption_key=sentinel.encryption_key, + signer=self.mock_signer, + required_ec_bytes=self.mock_required_ec_bytes, + ) + self.mock_encrypt.assert_called_once_with( + algorithm=self.mock_algorithm, + key=sentinel.encryption_key, + plaintext=b"", + associated_data=VALUES["serialized_header"] + self.mock_required_ec_bytes, + iv=mock_header_auth_iv.return_value, + ) + self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth"]) + assert test == VALUES["serialized_header_auth"] + @patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv") def test_serialize_header_auth_v2(self, mock_header_auth_iv): """Validate that the _create_header_auth function @@ -203,6 +232,33 @@ def test_serialize_header_auth_v2_no_signer(self): data_encryption_key=VALUES["data_key_obj"], ) + @patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv") + def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v2_THEN_aad_has_required_ec_bytes( + self, + mock_header_auth_iv, + ): + """Validate that the _create_header_auth function + behaves as expected for SerializationVersion.V2. 
+ """ + self.mock_encrypt.return_value = VALUES["header_auth_base"] + test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth( + version=SerializationVersion.V2, + algorithm=self.mock_algorithm, + header=VALUES["serialized_header_v2_committing"], + data_encryption_key=sentinel.encryption_key, + signer=self.mock_signer, + required_ec_bytes=self.mock_required_ec_bytes, + ) + self.mock_encrypt.assert_called_once_with( + algorithm=self.mock_algorithm, + key=sentinel.encryption_key, + plaintext=b"", + associated_data=VALUES["serialized_header_v2_committing"] + self.mock_required_ec_bytes, + iv=mock_header_auth_iv.return_value, + ) + self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth_v2"]) + assert test == VALUES["serialized_header_auth_v2"] + def test_serialize_non_framed_open(self): """Validate that the serialize_non_framed_open function behaves as expected. diff --git a/test/unit/test_streaming_client_configs.py b/test/unit/test_streaming_client_configs.py index 18886f65b..435aff0da 100644 --- a/test/unit/test_streaming_client_configs.py +++ b/test/unit/test_streaming_client_configs.py @@ -15,7 +15,7 @@ import pytest import six -from mock import patch +from mock import MagicMock, patch from aws_encryption_sdk import CommitmentPolicy from aws_encryption_sdk.internal.defaults import ALGORITHM, FRAME_LENGTH, LINE_LENGTH @@ -33,7 +33,7 @@ # Ideally, this logic would be based on mocking imports and testing logic, # but doing that introduces errors that cause other tests to fail. 
try: - from aws_cryptographic_materialproviders.mpl.references import IKeyring + from aws_cryptographic_materialproviders.mpl.references import ICryptographicMaterialsManager, IKeyring HAS_MPL = True from aws_encryption_sdk.materials_managers.mpl.cmm import CryptoMaterialsManagerFromMPL @@ -236,24 +236,28 @@ def test_client_configs_with_mpl( assert test.materials_manager is not None # If materials manager was provided, it should be directly used - if hasattr(kwargs, "materials_manager"): + if "materials_manager" in kwargs: assert kwargs["materials_manager"] == test.materials_manager + # If native key_provider was provided, it should be wrapped in native materials manager + elif "key_provider" in kwargs: + assert test.key_provider is not None + assert test.key_provider == kwargs["key_provider"] + assert isinstance(test.materials_manager, DefaultCryptoMaterialsManager) + # If MPL keyring was provided, it should be wrapped in MPL materials manager - if hasattr(kwargs, "keyring"): + elif "keyring" in kwargs: assert test.keyring is not None assert test.keyring == kwargs["keyring"] assert isinstance(test.keyring, IKeyring) assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL) - # If native key_provider was provided, it should be wrapped in native materials manager - if hasattr(kwargs, "key_provider"): - assert test.key_provider is not None - assert test.key_provider == kwargs["key_provider"] - assert isinstance(test.materials_manager, DefaultCryptoMaterialsManager) + else: + raise ValueError(f"Test did not find materials_manager or key_provider. 
{kwargs}") -# This needs its own test; pytest parametrize cannot use a conditionally-loaded type +# This is an addition to test_client_configs_with_mpl; +# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (IKeyring) @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") def test_keyring_client_config_with_mpl( ): @@ -265,16 +269,30 @@ def test_keyring_client_config_with_mpl( test = _ClientConfig(**kwargs) - # In all cases, config should have a materials manager assert test.materials_manager is not None - # If materials manager was provided, it should be directly used - if hasattr(kwargs, "materials_manager"): - assert kwargs["materials_manager"] == test.materials_manager + assert test.keyring is not None + assert test.keyring == kwargs["keyring"] + assert isinstance(test.keyring, IKeyring) + assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL) - # If MPL keyring was provided, it should be wrapped in MPL materials manager - if hasattr(kwargs, "keyring"): - assert test.keyring is not None - assert test.keyring == kwargs["keyring"] - assert isinstance(test.keyring, IKeyring) - assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL) + +# This is an addition to test_client_configs_with_mpl; +# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (MPL CMM) +@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") +def test_mpl_cmm_client_config_with_mpl( +): + mock_mpl_cmm = MagicMock(__class__=ICryptographicMaterialsManager) + kwargs = { + "source": b"", + "materials_manager": mock_mpl_cmm, + "commitment_policy": CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT + } + + test = _ClientConfig(**kwargs) + + assert test.materials_manager is not None + # Assert that the MPL CMM is wrapped in the native interface + assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL) + # 
Assert the MPL CMM is used by the native interface + assert test.materials_manager.mpl_cmm == mock_mpl_cmm diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index e06cad308..ce3d6ee3c 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit test suite for aws_encryption_sdk.streaming_client.StreamDecryptor""" +# noqa pylint: disable=too-many-lines import io import pytest @@ -193,6 +194,10 @@ def test_read_header(self, mock_derive_datakey, mock_decrypt_materials_request, test_decryptor.source_stream = ct_stream test_decryptor._stream_length = len(VALUES["data_128"]) + # Mock: hasattr(self.config, "encryption_context") returns False + if hasattr(test_decryptor.config, "encryption_context"): + del test_decryptor.config.encryption_context + test_header, test_header_auth = test_decryptor._read_header() self.mock_deserialize_header.assert_called_once_with(ct_stream, None) @@ -293,6 +298,39 @@ def test_GIVEN_verification_key_AND_has_mpl_AND_not_MPLCMM_WHEN_read_header_THEN algorithm=self.mock_header.algorithm, key_bytes=sentinel.verification_key ) + @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") + @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") + @patch("aws_encryption_sdk.streaming_client.Verifier") + # Given: no MPL + @pytest.mark.skipif(HAS_MPL, reason="Test should only be executed without MPL in installation") + def test_GIVEN_decrypt_config_has_ec_AND_no_mpl_WHEN_read_header_THEN_raise_TypeError( + self, + mock_verifier, + mock_decrypt_materials_request, + *_, + ): + + mock_verifier_instance = MagicMock() + mock_verifier.from_key_bytes.return_value = mock_verifier_instance + ct_stream = io.BytesIO(VALUES["data_128"]) + 
mock_commitment_policy = MagicMock(__class__=CommitmentPolicy) + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=ct_stream, + commitment_policy=mock_commitment_policy, + ) + test_decryptor.source_stream = ct_stream + test_decryptor._stream_length = len(VALUES["data_128"]) + # Given: self.config has "encryption_context" + # (i.e. encryption context provided on decrypt) + any_reproduced_ec = {"some": "ec"} + test_decryptor.config.encryption_context = any_reproduced_ec + + # Then: raise TypeError + with pytest.raises(TypeError): + # When: read header + test_decryptor._read_header() + @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") @patch("aws_encryption_sdk.streaming_client.Verifier") @@ -882,3 +920,219 @@ def test_close_no_footer(self, mock_close): with pytest.raises(SerializationError) as excinfo: test_decryptor.close() excinfo.match("Footer not read") + + @patch("aws_encryption_sdk.streaming_client.validate_header") + def test_GIVEN_does_not_have_required_EC_WHEN_validate_parsed_header_THEN_validate_header( + self, + mock_validate_header + ): + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + test_decryptor._derived_data_key = sentinel.derived_data_key + # Given: test_decryptor does not have _required_encryption_context attribute + # When: _validate_parsed_header + test_decryptor._validate_parsed_header( + header=self.mock_header, + header_auth=sentinel.header_auth, + raw_header=self.mock_raw_header + ) + # Then: validate_header + mock_validate_header.assert_called_once_with( + header=self.mock_header, + header_auth=sentinel.header_auth, + raw_header=self.mock_raw_header, + data_key=sentinel.derived_data_key, + ) + + 
@patch("aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context") + @patch("aws_encryption_sdk.streaming_client.validate_header") + def test_GIVEN_has_required_EC_WHEN_validate_parsed_header_THEN_validate_header_with_serialized_required_EC( + self, + mock_validate_header, + mock_serialize_encryption_context, + ): + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + test_decryptor._derived_data_key = sentinel.derived_data_key + # Given: test_decryptor has _required_encryption_context attribute + mock_required_ec = MagicMock(__class__=dict) + test_decryptor._required_encryption_context = mock_required_ec + mock_serialized_required_ec = MagicMock(__class__=bytes) + mock_serialize_encryption_context.return_value = mock_serialized_required_ec + # When: _validate_parsed_header + test_decryptor._validate_parsed_header( + header=self.mock_header, + header_auth=sentinel.header_auth, + raw_header=self.mock_raw_header + ) + # Then: call validate_header with serialized required EC + mock_validate_header.assert_called_once_with( + header=self.mock_header, + header_auth=sentinel.header_auth, + raw_header=self.mock_raw_header + mock_serialized_required_ec, + data_key=sentinel.derived_data_key, + ) + + # Given: has MPL + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") + def test_GIVEN_has_MPL_AND_config_has_EC_WHEN_create_decrypt_materials_request_THEN_provide_reproduced_EC( + self, + ): + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_mpl_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + + # Given: StreamDecryptor.config has encryption_context attribute + mock_reproduced_encryption_context = 
MagicMock(__class__=dict) + test_decryptor.config.encryption_context = mock_reproduced_encryption_context + # Type checking on header encryption context seems to require concrete instance, + # neither MagicMock nor sentinel value work + self.mock_header.encryption_context = {"some_key_to_pass_type_validation": "some_value"} + + # When: _create_decrypt_materials_request + output = test_decryptor._create_decrypt_materials_request( + header=self.mock_header, + ) + + # Then: decrypt_materials_request has reproduced_encryption_context attribute + assert hasattr(output, "reproduced_encryption_context") + assert output.reproduced_encryption_context == mock_reproduced_encryption_context + + def test_GIVEN_config_does_not_have_EC_WHEN_create_decrypt_materials_request_THEN_request_does_not_have_reproduced_EC( # noqa pylint: disable=line-too-long + self, + ): + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + + # Given: StreamDecryptor.config does not have an encryption_context attribute + del test_decryptor.config.encryption_context + # Type checking on header encryption context seems to require concrete instance, + # neither MagicMock nor sentinel value work + self.mock_header.encryption_context = {"some_key_to_pass_type_validation": "some_value"} + + # When: _create_decrypt_materials_request + output = test_decryptor._create_decrypt_materials_request( + header=self.mock_header, + ) + + # Then: decrypt_materials_request.reproduced_encryption_context is None + assert output.reproduced_encryption_context is None + + @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") + @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") + @patch("aws_encryption_sdk.streaming_client.Verifier") + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in 
installation") + def test_GIVEN_materials_has_no_required_encryption_context_keys_attr_WHEN_read_header_THEN_required_EC_is_None( + self, + mock_verifier, + *_ + ): + + mock_verifier_instance = MagicMock() + mock_verifier.from_key_bytes.return_value = mock_verifier_instance + + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + + # Given: decryption_materials does not have a required_encryption_context_keys attribute + del self.mock_decrypt_materials.required_encryption_context_keys + + # When: _read_header + test_decryptor._read_header() + + # Then: StreamDecryptor._required_encryption_context is None + assert test_decryptor._required_encryption_context is None + + @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") + @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") + @patch("aws_encryption_sdk.streaming_client.Verifier") + # Given: has MPL + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") + def test_GIVEN_materials_has_required_encryption_context_keys_attr_WHEN_read_header_THEN_creates_correct_required_EC( # noqa pylint: disable=line-too-long + self, + mock_verifier, + *_ + ): + required_encryption_context_keys_values = [ + # Case of empty encryption context list is not allowed; + # if a list is provided, it must be non-empty. + # The MPL enforces this behavior on construction. 
+ ["one_key"], + ["one_key", "two_key"], + ["one_key", "two_key", "red_key"], + ["one_key", "two_key", "red_key", "blue_key"], + ] + + encryption_context_values = [ + {}, + {"one_key": "some_value"}, + { + "one_key": "some_value", + "two_key": "some_other_value", + }, + { + "one_key": "some_value", + "two_key": "some_other_value", + "red_key": "some_red_value", + }, + { + "one_key": "some_value", + "two_key": "some_other_value", + "red_key": "some_red_value", + "blue_key": "some_blue_value", + } + ] + + for required_encryption_context_keys in required_encryption_context_keys_values: + + # Given: decryption_materials has required_encryption_context_keys + self.mock_decrypt_materials.required_encryption_context_keys = \ + required_encryption_context_keys + + for encryption_context in encryption_context_values: + + self.mock_decrypt_materials.encryption_context = encryption_context + + mock_verifier_instance = MagicMock() + mock_verifier.from_key_bytes.return_value = mock_verifier_instance + + self.mock_header.content_type = ContentType.FRAMED_DATA + test_decryptor = StreamDecryptor( + materials_manager=self.mock_materials_manager, + source=self.mock_input_stream, + commitment_policy=self.mock_commitment_policy, + ) + + # When: _read_header + test_decryptor._read_header() + + # Then: Assert correctness of partitioned EC + for k, v in encryption_context.items(): + # If a key is in required_encryption_context_keys, then ... + if k in required_encryption_context_keys: + # ... its EC is in the StreamEncryptor._required_encryption_context + assert k in test_decryptor._required_encryption_context + assert test_decryptor._required_encryption_context[k] == v + # If a key is NOT in required_encryption_context_keys, then ... + else: + # ... 
its EC is NOT in the StreamEncryptor._required_encryption_context + assert k not in test_decryptor._required_encryption_context diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index e43752689..4df79e146 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit test suite for aws_encryption_sdk.streaming_client.StreamEncryptor""" +# noqa pylint: disable=too-many-lines import io import pytest @@ -451,6 +452,119 @@ def test_GIVEN_has_mpl_AND_has_MPLCMM_AND_uses_signer_WHEN_prep_message_THEN_sig encoding=serialization.Encoding.PEM ) + # Given: has MPL + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") + def test_GIVEN_has_mpl_AND_encryption_materials_has_required_EC_keys_WHEN_prep_message_THEN_paritions_stored_and_required_EC( # noqa pylint: disable=line-too-long + self + ): + # Create explicit values to explicitly test logic in smaller cases + required_encryption_context_keys_values = [ + # Case of empty encryption context list is not allowed; + # if a list is provided, it must be non-empty. + # The MPL enforces this behavior on construction. 
+ ["one_key"], + ["one_key", "two_key"], + ["one_key", "two_key", "red_key"], + ["one_key", "two_key", "red_key", "blue_key"], + ] + + encryption_context_values = [ + {}, + {"one_key": "some_value"}, + { + "one_key": "some_value", + "two_key": "some_other_value", + }, + { + "one_key": "some_value", + "two_key": "some_other_value", + "red_key": "some_red_value", + }, + { + "one_key": "some_value", + "two_key": "some_other_value", + "red_key": "some_red_value", + "blue_key": "some_blue_value", + } + ] + + self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16 + + for required_encryption_context_keys in required_encryption_context_keys_values: + + # Given: encryption context has required_encryption_context_keys + self.mock_encryption_materials.required_encryption_context_keys = \ + required_encryption_context_keys + + for encryption_context in encryption_context_values: + self.mock_encryption_materials.encryption_context = encryption_context + + test_encryptor = StreamEncryptor( + source=VALUES["data_128"], + materials_manager=self.mock_mpl_materials_manager, + frame_length=self.mock_frame_length, + algorithm=Algorithm.AES_128_GCM_IV12_TAG16, + commitment_policy=self.mock_commitment_policy, + signature_policy=self.mock_signature_policy, + ) + test_encryptor.content_type = ContentType.FRAMED_DATA + # When: prep_message + test_encryptor._prep_message() + + # Then: Assert correctness of partitioned EC + for k, v in encryption_context.items(): + # If a key is in required_encryption_context_keys, then + if k in required_encryption_context_keys: + # 1) Its EC is in the StreamEncryptor._required_encryption_context + assert k in test_encryptor._required_encryption_context + assert test_encryptor._required_encryption_context[k] == v + # 2) Its EC is NOT in the StreamEncryptor._stored_encryption_context + assert k not in test_encryptor._stored_encryption_context + # If a key is NOT in required_encryption_context_keys, then + else: + # 1) Its EC is NOT in 
the StreamEncryptor._required_encryption_context + assert k not in test_encryptor._required_encryption_context + # 2) Its EC is in the StreamEncryptor._stored_encryption_context + assert k in test_encryptor._stored_encryption_context + assert test_encryptor._stored_encryption_context[k] == v + + # Assert size(stored_EC) + size(required_EC) == size(EC) + # (i.e. every EC was sorted into one or the other) + assert len(test_encryptor._required_encryption_context) \ + + len(test_encryptor._stored_encryption_context) \ + == len(encryption_context) + + # Given: has MPL + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") + def test_GIVEN_has_mpl_AND_encryption_materials_does_not_have_required_EC_keys_WHEN_prep_message_THEN_stored_EC_is_EC( # noqa pylint: disable=line-too-long + self + ): + + self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16 + + mock_encryption_context = MagicMock(__class__=dict) + self.mock_encryption_materials.encryption_context = mock_encryption_context + # Given: encryption materials does not have required encryption context keys + # (MagicMock default is to "make up" "Some" value here; this deletes that value) + del self.mock_encryption_materials.required_encryption_context_keys + + test_encryptor = StreamEncryptor( + source=VALUES["data_128"], + materials_manager=self.mock_mpl_materials_manager, + frame_length=self.mock_frame_length, + algorithm=Algorithm.AES_128_GCM_IV12_TAG16, + commitment_policy=self.mock_commitment_policy, + signature_policy=self.mock_signature_policy, + ) + test_encryptor.content_type = ContentType.FRAMED_DATA + # When: prep_message + test_encryptor._prep_message() + + # Then: _stored_encryption_context is the provided encryption_context + assert test_encryptor._stored_encryption_context == mock_encryption_context + # Then: _required_encryption_context is None + assert test_encryptor._required_encryption_context is None + def 
test_prep_message_no_signer(self): self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16 test_encryptor = StreamEncryptor( @@ -575,6 +689,53 @@ def test_write_header(self): ) assert test_encryptor.output_buffer == b"1234567890" + @patch("aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context") + # Given: has MPL + @pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation") + def test_GIVEN_has_mpl_AND_has_required_EC_WHEN_write_header_THEN_adds_serialized_required_ec_to_header_auth( + self, + serialize_encryption_context + ): + self.mock_serialize_header.return_value = b"12345" + self.mock_serialize_header_auth.return_value = b"67890" + pt_stream = io.BytesIO(self.plaintext) + test_encryptor = StreamEncryptor( + source=pt_stream, + materials_manager=self.mock_materials_manager, + algorithm=aws_encryption_sdk.internal.defaults.ALGORITHM, + frame_length=self.mock_frame_length, + commitment_policy=self.mock_commitment_policy, + signature_policy=self.mock_signature_policy, + ) + test_encryptor.signer = sentinel.signer + test_encryptor.content_type = sentinel.content_type + test_encryptor._header = sentinel.header + sentinel.header.version = SerializationVersion.V1 + test_encryptor.output_buffer = b"" + test_encryptor._encryption_materials = self.mock_encryption_materials + test_encryptor._derived_data_key = sentinel.derived_data_key + + # Given: StreamEncryptor has _required_encryption_context + mock_required_ec = MagicMock(__class__=dict) + test_encryptor._required_encryption_context = mock_required_ec + mock_serialized_required_ec = MagicMock(__class__=bytes) + serialize_encryption_context.return_value = mock_serialized_required_ec + + # When: _write_header() + test_encryptor._write_header() + + self.mock_serialize_header.assert_called_once_with(header=test_encryptor._header, signer=sentinel.signer) + self.mock_serialize_header_auth.assert_called_once_with( + 
version=sentinel.header.version, + algorithm=self.mock_encryption_materials.algorithm, + header=b"12345", + data_encryption_key=sentinel.derived_data_key, + signer=sentinel.signer, + # Then: Pass serialized required EC to serialize_header_auth + required_ec_bytes=mock_serialized_required_ec, + ) + assert test_encryptor.output_buffer == b"1234567890" + @patch("aws_encryption_sdk.streaming_client.non_framed_body_iv") def test_prep_non_framed(self, mock_non_framed_iv): self.mock_serialize_non_framed_open.return_value = b"1234567890" diff --git a/test_vector_handlers/test/aws-crypto-tools-test-vector-framework b/test_vector_handlers/test/aws-crypto-tools-test-vector-framework index c3d73fae2..9eb2fcbbe 160000 --- a/test_vector_handlers/test/aws-crypto-tools-test-vector-framework +++ b/test_vector_handlers/test/aws-crypto-tools-test-vector-framework @@ -1 +1 @@ -Subproject commit c3d73fae260fd9e9cc9e746f09a7ffbab83576e2 +Subproject commit 9eb2fcbbe47ab30c29d6ad9a8125b1064e0db42a