Skip to content

fix: expose the KMS keyring key namespace value for public access #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/aws_encryption_sdk/keyrings/aws_kms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@
# We only actually need these imports when running the mypy checks
pass

__all__ = ("KmsKeyring",)
__all__ = ("KmsKeyring", "KEY_NAMESPACE")

_LOGGER = logging.getLogger(__name__)
_PROVIDER_ID = "aws-kms"
_GENERATE_FLAGS = {KeyringTraceFlag.GENERATED_DATA_KEY}
_ENCRYPT_FLAGS = {KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT}
_DECRYPT_FLAGS = {KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT}

#: Key namespace used for all encrypted data keys created by the KMS keyring.
KEY_NAMESPACE = "aws-kms"


@attr.s
class KmsKeyring(Keyring):
Expand Down Expand Up @@ -179,7 +181,7 @@ class _AwsKmsSingleCmkKeyring(Keyring):

def on_encrypt(self, encryption_materials):
# type: (EncryptionMaterials) -> EncryptionMaterials
trace_info = MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=self._key_id)
trace_info = MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=self._key_id)
new_materials = encryption_materials
try:
if new_materials.data_encryption_key is None:
Expand Down Expand Up @@ -221,7 +223,7 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys):
return new_materials

if (
edk.key_provider.provider_id == _PROVIDER_ID
edk.key_provider.provider_id == KEY_NAMESPACE
and edk.key_provider.key_info.decode("utf-8") == self._key_id
):
new_materials = _try_aws_kms_decrypt(
Expand Down Expand Up @@ -265,7 +267,7 @@ def on_decrypt(self, decryption_materials, encrypted_data_keys):
if new_materials.data_encryption_key is not None:
return new_materials

if edk.key_provider.provider_id == _PROVIDER_ID:
if edk.key_provider.provider_id == KEY_NAMESPACE:
new_materials = _try_aws_kms_decrypt(
client_supplier=self._client_supplier,
decryption_materials=new_materials,
Expand Down Expand Up @@ -327,7 +329,7 @@ def _do_aws_kms_decrypt(client_supplier, key_name, encrypted_data_key, encryptio
" actual '{actual}' != expected '{expected}'".format(actual=response_key_id, expected=key_name)
)
return RawDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response_key_id), data_key=response["Plaintext"]
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response_key_id), data_key=response["Plaintext"]
)


Expand All @@ -346,7 +348,7 @@ def _do_aws_kms_encrypt(client_supplier, key_name, plaintext_data_key, encryptio
GrantTokens=grant_tokens,
)
return EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response["KeyId"]),
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response["KeyId"]),
encrypted_data_key=response["CiphertextBlob"],
)

Expand All @@ -368,7 +370,7 @@ def _do_aws_kms_generate_data_key(client_supplier, key_name, encryption_context,
EncryptionContext=encryption_context,
GrantTokens=grant_tokens,
)
provider = MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response["KeyId"])
provider = MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response["KeyId"])
plaintext_key = RawDataKey(key_provider=provider, data_key=response["Plaintext"])
encrypted_key = EncryptedDataKey(key_provider=provider, encrypted_data_key=response["CiphertextBlob"])
return plaintext_key, encrypted_key
Expand Down
34 changes: 17 additions & 17 deletions test/functional/keyrings/aws_kms/test_aws_kms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aws_encryption_sdk.identifiers import KeyringTraceFlag
from aws_encryption_sdk.internal.defaults import ALGORITHM
from aws_encryption_sdk.keyrings.aws_kms import (
_PROVIDER_ID,
KEY_NAMESPACE,
KmsKeyring,
_AwsKmsDiscoveryKeyring,
_AwsKmsSingleCmkKeyring,
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_aws_kms_single_cmk_keyring_on_encrypt_empty_materials(fake_generator):
assert len(result_materials.encrypted_data_keys) == 1

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=fake_generator), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=fake_generator), result_materials.keyring_trace
)

assert KeyringTraceFlag.GENERATED_DATA_KEY in generator_flags
Expand All @@ -84,7 +84,7 @@ def test_aws_kms_single_cmk_keyring_on_encrypt_existing_data_key(fake_generator)
assert len(result_materials.encrypted_data_keys) == 1

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=fake_generator), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=fake_generator), result_materials.keyring_trace
)

assert KeyringTraceFlag.GENERATED_DATA_KEY not in generator_flags
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_aws_kms_single_cmk_keyring_on_decrypt_existing_datakey(caplog):
decryption_materials=initial_materials,
encrypted_data_keys=(
EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"foo"), encrypted_data_key=b"bar"
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=b"foo"), encrypted_data_key=b"bar"
),
),
)
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_aws_kms_single_cmk_keyring_on_decrypt_single_cmk(fake_generator):
assert result_materials.data_encryption_key is not None

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=fake_generator), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=fake_generator), result_materials.keyring_trace
)

assert KeyringTraceFlag.DECRYPTED_DATA_KEY in generator_flags
Expand All @@ -180,12 +180,12 @@ def test_aws_kms_single_cmk_keyring_on_decrypt_multiple_cmk(fake_generator_and_c
)

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=generator), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=generator), result_materials.keyring_trace
)
assert len(generator_flags) == 0

child_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=child), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=child), result_materials.keyring_trace
)

assert KeyringTraceFlag.DECRYPTED_DATA_KEY in child_flags
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_aws_kms_single_cmk_keyring_on_decrypt_fail(caplog):
decryption_materials=initial_materials,
encrypted_data_keys=(
EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"foo"), encrypted_data_key=b"bar"
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=b"foo"), encrypted_data_key=b"bar"
),
),
)
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_aws_kms_discovery_keyring_on_decrypt(encryption_materials_for_discovery
assert result_materials.data_encryption_key is not None

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=generator_key_id), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=generator_key_id), result_materials.keyring_trace
)

assert KeyringTraceFlag.DECRYPTED_DATA_KEY in generator_flags
Expand All @@ -300,7 +300,7 @@ def test_aws_kms_discovery_keyring_on_decrypt_existing_data_key(caplog):
decryption_materials=initial_materials,
encrypted_data_keys=(
EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"foo"), encrypted_data_key=b"bar"
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=b"foo"), encrypted_data_key=b"bar"
),
),
)
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_aws_kms_discovery_keyring_on_decrypt_fail(caplog):
decryption_materials=initial_materials,
encrypted_data_keys=(
EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"bar"), encrypted_data_key=b"bar"
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=b"bar"), encrypted_data_key=b"bar"
),
),
)
Expand All @@ -365,7 +365,7 @@ def test_try_aws_kms_decrypt_succeed(fake_generator):
response = kms.encrypt(KeyId=fake_generator, Plaintext=plaintext, EncryptionContext=encryption_context)

encrypted_data_key = EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response["KeyId"]),
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response["KeyId"]),
encrypted_data_key=response["CiphertextBlob"],
)

Expand All @@ -381,7 +381,7 @@ def test_try_aws_kms_decrypt_succeed(fake_generator):
assert result_materials.data_encryption_key.data_key == plaintext

generator_flags = _matching_flags(
MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=fake_generator), result_materials.keyring_trace
MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=fake_generator), result_materials.keyring_trace
)

assert KeyringTraceFlag.DECRYPTED_DATA_KEY in generator_flags
Expand All @@ -394,7 +394,7 @@ def test_try_aws_kms_decrypt_error(caplog):
caplog.set_level(logging.DEBUG)

encrypted_data_key = EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=b"foo"), encrypted_data_key=b"bar"
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=b"foo"), encrypted_data_key=b"bar"
)

initial_decryption_materials = DecryptionMaterials(algorithm=ALGORITHM, encryption_context={},)
Expand All @@ -420,7 +420,7 @@ def test_do_aws_kms_decrypt(fake_generator):
response = kms.encrypt(KeyId=fake_generator, Plaintext=plaintext, EncryptionContext=encryption_context)

encrypted_data_key = EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response["KeyId"]),
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response["KeyId"]),
encrypted_data_key=response["CiphertextBlob"],
)

Expand All @@ -442,7 +442,7 @@ def test_do_aws_kms_decrypt_unexpected_key_id(fake_generator_and_child):
response = kms.encrypt(KeyId=encryptor, Plaintext=plaintext, EncryptionContext=encryption_context)

encrypted_data_key = EncryptedDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=response["KeyId"]),
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=response["KeyId"]),
encrypted_data_key=response["CiphertextBlob"],
)

Expand All @@ -466,7 +466,7 @@ def test_do_aws_kms_encrypt(fake_generator):
client_supplier=DefaultClientSupplier(),
key_name=fake_generator,
plaintext_data_key=RawDataKey(
key_provider=MasterKeyInfo(provider_id=_PROVIDER_ID, key_info=fake_generator), data_key=plaintext
key_provider=MasterKeyInfo(provider_id=KEY_NAMESPACE, key_info=fake_generator), data_key=plaintext
),
encryption_context=encryption_context,
grant_tokens=[],
Expand Down