Skip to content

Commit b15b866

Browse files
committed
Added custom exception for enc_context mismatch, used pytest fixtures in unit tests
1 parent 2955c9c commit b15b866

File tree

3 files changed

+118
-97
lines changed

3 files changed

+118
-97
lines changed

aws_lambda_powertools/utilities/data_masking/providers/aws_encryption_sdk.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88
LocalCryptoMaterialsCache,
99
StrictAwsKmsMasterKeyProvider,
1010
)
11-
1211
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
1312
from aws_lambda_powertools.shared.user_agent import register_feature_to_botocore_session
1413

1514

15+
class ContextMismatchError(Exception):
16+
def __init__(self, key):
17+
super().__init__(f"Encryption Context does not match expected value for key: {key}")
18+
self.key = key
19+
20+
1621
class SingletonMeta(type):
1722
"""Metaclass to cache class instances to optimize encryption"""
1823

@@ -32,27 +37,70 @@ def __call__(cls, *args, **provider_options):
3237

3338

3439
class AwsEncryptionSdkProvider(BaseProvider):
35-
cache = LocalCryptoMaterialsCache(CACHE_CAPACITY)
40+
"""
41+
The AwsEncryptionSdkProvider is to be used as a Provider for the Datamasking class.
42+
Example:
43+
>>> data_masker = DataMasking(provider=AwsEncryptionSdkProvider(keys=[keyARN1, keyARN2,...,]))
44+
>>> encrypted_data = data_masker.encrypt("a string")
45+
"encrptedBase64String"
46+
>>> decrypted_data = data_masker.decrypt(encrypted_data)
47+
"a string"
48+
"""
49+
3650
session = botocore.session.Session()
3751
register_feature_to_botocore_session(session, "data-masking")
3852

39-
def __init__(self, keys: List[str], client: Optional[EncryptionSDKClient] = None):
53+
def __init__(
54+
self,
55+
keys: List[str],
56+
client: Optional[EncryptionSDKClient] = None,
57+
local_cache_capacity: Optional[int] = CACHE_CAPACITY,
58+
max_cache_age_seconds: Optional[float] = MAX_ENTRY_AGE_SECONDS,
59+
max_messages: Optional[int] = MAX_MESSAGES,
60+
):
4061
self.client = client or EncryptionSDKClient()
4162
self.keys = keys
63+
self.cache = LocalCryptoMaterialsCache(local_cache_capacity)
4264
self.key_provider = StrictAwsKmsMasterKeyProvider(key_ids=self.keys, botocore_session=self.session)
4365
self.cache_cmm = CachingCryptoMaterialsManager(
4466
master_key_provider=self.key_provider,
4567
cache=self.cache,
46-
max_age=MAX_ENTRY_AGE_SECONDS,
47-
max_messages_encrypted=MAX_MESSAGES,
68+
max_age=max_cache_age_seconds,
69+
max_messages_encrypted=max_messages,
4870
)
4971

5072
def encrypt(self, data: Union[bytes, str], **provider_options) -> str:
73+
"""
74+
Encrypt data using the AwsEncryptionSdkProvider.
75+
76+
Parameters:
77+
- data (Union[bytes, str]):
78+
The data to be encrypted.
79+
- provider_options:
80+
Additional options for the aws_encryption_sdk.EncryptionSDKClient
81+
82+
Returns:
83+
- ciphertext (str):
84+
The encrypted data, as a base64-encoded string.
85+
"""
5186
ciphertext, _ = self.client.encrypt(source=data, materials_manager=self.cache_cmm, **provider_options)
5287
ciphertext = base64.b64encode(ciphertext).decode()
5388
return ciphertext
5489

5590
def decrypt(self, data: str, **provider_options) -> bytes:
91+
"""
92+
Decrypt data using AwsEncryptionSdkProvider.
93+
94+
Parameters:
95+
- data (Union[bytes, str]):
96+
The encrypted data, as a base64-encoded string.
97+
- provider_options:
98+
Additional options for the aws_encryption_sdk.EncryptionSDKClient
99+
100+
Returns:
101+
- ciphertext (bytes):
102+
The decrypted data in bytes
103+
"""
56104
ciphertext_decoded = base64.b64decode(data)
57105

58106
expected_context = provider_options.pop("encryption_context", {})
@@ -63,6 +111,6 @@ def decrypt(self, data: str, **provider_options) -> bytes:
63111

64112
for key, value in expected_context.items():
65113
if decryptor_header.encryption_context.get(key) != value:
66-
raise ValueError(f"Encryption Context does not match expected value for key: {key}")
114+
raise ContextMismatchError(key)
67115

68116
return ciphertext

tests/e2e/data_masking/test_data_masking.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import pytest
55
from tests.e2e.utils import data_fetcher
66
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
7-
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import AwsEncryptionSdkProvider
7+
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import (
8+
AwsEncryptionSdkProvider,
9+
ContextMismatchError,
10+
)
811

912

1013
@pytest.fixture
@@ -52,26 +55,27 @@ def test_encryption_context(data_masker):
5255
# GIVEN an instantiation of DataMasking with the AWS encryption provider
5356

5457
value = bytes(str([1, 2, "string", 4.5]), "utf-8")
58+
context = {"this": "is_secure"}
5559

5660
# WHEN encrypting and then decrypting the encrypted data with an encryption_context
57-
encrypted_data = data_masker.encrypt(value, encryption_context={"this": "is_secure"})
58-
decrypted_data = data_masker.decrypt(encrypted_data, encryption_context={"this": "is_secure"})
61+
encrypted_data = data_masker.encrypt(value, encryption_context=context)
62+
decrypted_data = data_masker.decrypt(encrypted_data, encryption_context=context)
5963

6064
# THEN the result is the original input data
6165
assert decrypted_data == value
6266

6367

6468
@pytest.mark.xdist_group(name="data_masking")
65-
def test_encryption_diff_context_fail(data_masker):
69+
def test_encryption_context_mismatch(data_masker):
6670
# GIVEN an instantiation of DataMasking with the AWS encryption provider
6771

6872
value = bytes(str([1, 2, "string", 4.5]), "utf-8")
6973

7074
# WHEN encrypting with a encryption_context
7175
encrypted_data = data_masker.encrypt(value, encryption_context={"this": "is_secure"})
7276

73-
# THEN decrypting with a different encryption_context should raise a ValueError
74-
with pytest.raises(ValueError):
77+
# THEN decrypting with a different encryption_context should raise a ContextMismatchError
78+
with pytest.raises(ContextMismatchError):
7579
data_masker.decrypt(encrypted_data, encryption_context={"not": "same_context"})
7680

7781

@@ -84,14 +88,14 @@ def test_encryption_no_context_fail(data_masker):
8488
# WHEN encrypting with no encryption_context
8589
encrypted_data = data_masker.encrypt(value)
8690

87-
# THEN decrypting with an encryption_context should raise a ValueError
88-
with pytest.raises(ValueError):
91+
# THEN decrypting with an encryption_context should raise a ContextMismatchError
92+
with pytest.raises(ContextMismatchError):
8993
data_masker.decrypt(encrypted_data, encryption_context={"this": "is_secure"})
9094

9195

9296
# TODO: metaclass?
9397
@pytest.mark.xdist_group(name="data_masking")
94-
def test_encryption_key_fail(data_masker, kms_key2_arn):
98+
def test_encryption_decryption_key_mismatch(data_masker, kms_key2_arn):
9599
# GIVEN an instantiation of DataMasking with the AWS encryption provider with a certain key
96100

97101
# WHEN encrypting and then decrypting the encrypted data
@@ -106,7 +110,7 @@ def test_encryption_key_fail(data_masker, kms_key2_arn):
106110

107111

108112
@pytest.mark.xdist_group(name="data_masking")
109-
def test_encrypted_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn):
113+
def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn):
110114
# GIVEN an instantiation of DataMasking with the AWS encryption provider
111115

112116
# WHEN encrypting a value and logging it

0 commit comments

Comments
 (0)