Skip to content

Commit 7883a48

Browse files
committed
Revised singleton class to allow for one instance per different configuration
1 parent 7483d46 commit 7883a48

File tree

8 files changed

+70
-40
lines changed

8 files changed

+70
-40
lines changed

aws_lambda_powertools/utilities/data_masking/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Union, Optional
2+
from typing import Optional, Union
33

44
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
55

@@ -94,7 +94,7 @@ def _apply_action_to_fields(self, data: Union[dict, str], fields, action, **prov
9494
else:
9595
raise TypeError(
9696
"Unsupported data type. The 'data' parameter must be a dictionary or a JSON string "
97-
"representation of a dictionary."
97+
"representation of a dictionary.",
9898
)
9999

100100
for field in fields:

aws_lambda_powertools/utilities/data_masking/providers/aws_encryption_sdk.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
LocalCryptoMaterialsCache,
99
StrictAwsKmsMasterKeyProvider,
1010
)
11-
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
11+
1212
from aws_lambda_powertools.shared.user_agent import register_feature_to_botocore_session
13+
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
1314

1415

1516
class ContextMismatchError(Exception):
@@ -18,25 +19,28 @@ def __init__(self, key):
1819
self.key = key
1920

2021

21-
class SingletonMeta(type):
22-
"""Metaclass to cache class instances to optimize encryption"""
22+
class Singleton:
23+
_instances: Dict[Any, "AwsEncryptionSdkProvider"] = {}
2324

24-
_instances: Dict["AwsEncryptionSdkProvider", Any] = {}
25+
def __new__(cls, *args, **kwargs):
26+
# Generate a unique key based on the configuration
27+
# Create a tuple by iterating through the values in kwargs, sorting them,
28+
# and then adding them to the tuple.
29+
config_key = tuple(v for value in kwargs.values() for v in sorted(value))
2530

26-
def __call__(cls, *args, **provider_options):
27-
if cls not in cls._instances:
28-
instance = super().__call__(*args, **provider_options)
29-
cls._instances[cls] = instance
30-
return cls._instances[cls]
31+
if config_key not in cls._instances:
32+
cls._instances[config_key] = super(Singleton, cls).__new__(cls, *args)
33+
print("in if class instances:", cls._instances)
34+
return cls._instances[config_key]
3135

3236

3337
CACHE_CAPACITY: int = 100
34-
MAX_ENTRY_AGE_SECONDS: float = 300.0
35-
MAX_MESSAGES: int = 200
38+
MAX_CACHE_AGE_SECONDS: float = 300.0
39+
MAX_MESSAGES_ENCRYPTED: int = 200
3640
# NOTE: You can also set max messages/bytes per data key
3741

3842

39-
class AwsEncryptionSdkProvider(BaseProvider):
43+
class AwsEncryptionSdkProvider(BaseProvider, Singleton):
4044
"""
4145
The AwsEncryptionSdkProvider is to be used as a Provider for the Datamasking class.
4246
@@ -57,8 +61,8 @@ def __init__(
5761
keys: List[str],
5862
client: Optional[EncryptionSDKClient] = None,
5963
local_cache_capacity: Optional[int] = CACHE_CAPACITY,
60-
max_cache_age_seconds: Optional[float] = MAX_ENTRY_AGE_SECONDS,
61-
max_messages: Optional[int] = MAX_MESSAGES,
64+
max_cache_age_seconds: Optional[float] = MAX_CACHE_AGE_SECONDS,
65+
max_messages_encrypted: Optional[int] = MAX_MESSAGES_ENCRYPTED,
6266
):
6367
self.client = client or EncryptionSDKClient()
6468
self.keys = keys
@@ -68,7 +72,7 @@ def __init__(
6872
master_key_provider=self.key_provider,
6973
cache=self.cache,
7074
max_age=max_cache_age_seconds,
71-
max_messages_encrypted=max_messages,
75+
max_messages_encrypted=max_messages_encrypted,
7276
)
7377

7478
def encrypt(self, data: Union[bytes, str], **provider_options) -> str:
@@ -112,7 +116,9 @@ def decrypt(self, data: str, **provider_options) -> bytes:
112116
expected_context = provider_options.pop("encryption_context", {})
113117

114118
ciphertext, decryptor_header = self.client.decrypt(
115-
source=ciphertext_decoded, key_provider=self.key_provider, **provider_options
119+
source=ciphertext_decoded,
120+
key_provider=self.key_provider,
121+
**provider_options,
116122
)
117123

118124
for key, value in expected_context.items():

tests/e2e/data_masking/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
23
from tests.e2e.data_masking.infrastructure import DataMaskingStack
34

45

tests/e2e/data_masking/handlers/basic_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from aws_lambda_powertools import Logger
12
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
23
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import AwsEncryptionSdkProvider
3-
from aws_lambda_powertools import Logger
44

55
logger = Logger()
66

tests/e2e/data_masking/infrastructure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import aws_cdk.aws_kms as kms
22
from aws_cdk import CfnOutput, Duration
3+
34
from tests.e2e.utils.infrastructure import BaseInfrastructure
45

56

tests/e2e/data_masking/test_data_masking.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import json
22
from uuid import uuid4
3-
from aws_encryption_sdk.exceptions import DecryptKeyError
3+
44
import pytest
5-
from tests.e2e.utils import data_fetcher
5+
from aws_encryption_sdk.exceptions import DecryptKeyError
6+
67
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
78
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import (
89
AwsEncryptionSdkProvider,
910
ContextMismatchError,
1011
)
12+
from tests.e2e.utils import data_fetcher
1113

1214

1315
@pytest.fixture
@@ -93,7 +95,6 @@ def test_encryption_no_context_fail(data_masker):
9395
data_masker.decrypt(encrypted_data, encryption_context={"this": "is_secure"})
9496

9597

96-
# TODO: metaclass?
9798
@pytest.mark.xdist_group(name="data_masking")
9899
def test_encryption_decryption_key_mismatch(data_masker, kms_key2_arn):
99100
# GIVEN an instantiation of DataMasking with the AWS encryption provider with a certain key
@@ -109,6 +110,21 @@ def test_encryption_decryption_key_mismatch(data_masker, kms_key2_arn):
109110
data_masker_key2.decrypt(encrypted_data)
110111

111112

113+
def test_encryption_provider_singleton(data_masker, kms_key1_arn, kms_key2_arn):
114+
data_masker_2 = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key1_arn]))
115+
assert data_masker.provider is data_masker_2.provider
116+
117+
# WHEN encrypting and then decrypting the encrypted data
118+
encrypted_data = data_masker.encrypt("string")
119+
decrypted_data = data_masker_2.decrypt(encrypted_data)
120+
121+
# THEN the result is the original input data
122+
assert decrypted_data == bytes("string", "utf-8")
123+
124+
data_masker_3 = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key2_arn]))
125+
assert data_masker_2.provider is not data_masker_3.provider
126+
127+
112128
@pytest.mark.xdist_group(name="data_masking")
113129
def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn):
114130
# GIVEN an instantiation of DataMasking with the AWS encryption provider
@@ -132,6 +148,7 @@ def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn)
132148
decrypted_data = data_masker.decrypt(encrypted_data)
133149
assert decrypted_data == value
134150

151+
135152
# NOTE: This test is failing currently, need to find a fix for building correct dependencies
136153
@pytest.mark.xdist_group(name="data_masking")
137154
def test_encryption_in_handler(basic_handler_fn_arn, kms_key1_arn):

tests/functional/data_masking/test_aws_encryption_sdk.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
from unittest.mock import patch
2-
31
import pytest
42

53
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
64
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import AwsEncryptionSdkProvider
75
from tests.unit.data_masking.setup import *
86

9-
107
AWS_SDK_KEY = "arn:aws:kms:us-west-2:683517028648:key/269301eb-81eb-4067-ac72-98e8e49bf2b3"
118

129

tests/unit/data_masking/test_data_masking.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
2+
23
import pytest
34
from itsdangerous.url_safe import URLSafeSerializer
4-
from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING
5+
56
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
7+
from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING
68
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
79

810

@@ -102,7 +104,7 @@ def test_mask_dict(data_masker):
102104
"a": {
103105
"1": {"None": "hello", "four": "world"},
104106
"b": {"3": {"4": "goodbye", "e": "world"}},
105-
}
107+
},
106108
}
107109

108110
# WHEN mask is called with no fields argument
@@ -118,15 +120,18 @@ def test_mask_dict_with_fields(data_masker):
118120
"a": {
119121
"1": {"None": "hello", "four": "world"},
120122
"b": {"3": {"4": "goodbye", "e": "world"}},
121-
}
123+
},
122124
}
123125

124126
# WHEN mask is called with a list of fields specified
125127
masked_string = data_masker.mask(data, fields=["a.1.None", "a.b.3.4"])
126128

127129
# THEN the result is only the specified fields are masked
128130
assert masked_string == {
129-
"a": {"1": {"None": DATA_MASKING_STRING, "four": "world"}, "b": {"3": {"4": DATA_MASKING_STRING, "e": "world"}}}
131+
"a": {
132+
"1": {"None": DATA_MASKING_STRING, "four": "world"},
133+
"b": {"3": {"4": DATA_MASKING_STRING, "e": "world"}},
134+
},
130135
}
131136

132137

@@ -137,16 +142,19 @@ def test_mask_json_dict_with_fields(data_masker):
137142
"a": {
138143
"1": {"None": "hello", "four": "world"},
139144
"b": {"3": {"4": "goodbye", "e": "world"}},
140-
}
141-
}
145+
},
146+
},
142147
)
143148

144149
# WHEN mask is called with a list of fields specified
145150
masked_json_string = data_masker.mask(data, fields=["a.1.None", "a.b.3.4"])
146151

147152
# THEN the result is only the specified fields are masked
148153
assert masked_json_string == {
149-
"a": {"1": {"None": DATA_MASKING_STRING, "four": "world"}, "b": {"3": {"4": DATA_MASKING_STRING, "e": "world"}}}
154+
"a": {
155+
"1": {"None": DATA_MASKING_STRING, "four": "world"},
156+
"b": {"3": {"4": DATA_MASKING_STRING, "e": "world"}},
157+
},
150158
}
151159

152160

@@ -180,7 +188,7 @@ def test_encrypt_decrypt_bool(custom_data_masker):
180188
decrypted_data = custom_data_masker.decrypt(encrypted_data)
181189

182190
# THEN the result is the original input data
183-
assert decrypted_data == True
191+
assert decrypted_data is True
184192

185193

186194
def test_encrypt_decrypt_none(custom_data_masker):
@@ -191,7 +199,7 @@ def test_encrypt_decrypt_none(custom_data_masker):
191199
decrypted_data = custom_data_masker.decrypt(encrypted_data)
192200

193201
# THEN the result is the original input data
194-
assert decrypted_data == None
202+
assert decrypted_data is None
195203

196204

197205
def test_encrypt_decrypt_str(custom_data_masker):
@@ -223,7 +231,7 @@ def test_dict_encryption_with_fields(custom_data_masker):
223231
"a": {
224232
"1": {"None": "hello", "four": "world"},
225233
"b": {"3": {"4": "goodbye", "e": "world"}},
226-
}
234+
},
227235
}
228236

229237
# WHEN encrypting and decrypting the data with a list of fields
@@ -242,8 +250,8 @@ def test_json_encryption_with_fields(custom_data_masker):
242250
"a": {
243251
"1": {"None": "hello", "four": "world"},
244252
"b": {"3": {"4": "goodbye", "e": "world"}},
245-
}
246-
}
253+
},
254+
},
247255
)
248256

249257
# WHEN encrypting and decrypting a json representation of a dictionary with a list of fields
@@ -330,7 +338,7 @@ def test_parsing_nonexistent_fields(data_masker):
330338
"3": {
331339
"1": {"None": "hello", "four": "world"},
332340
"4": {"33": {"5": "goodbye", "e": "world"}},
333-
}
341+
},
334342
}
335343

336344
# WHEN attempting to pass in fields that do not exist in the input data
@@ -347,7 +355,7 @@ def test_parsing_nonstring_fields(data_masker):
347355
"3": {
348356
"1": {"None": "hello", "four": "world"},
349357
"4": {"33": {"5": "goodbye", "e": "world"}},
350-
}
358+
},
351359
}
352360

353361
# WHEN attempting to pass in a list of fields that are not strings
@@ -365,7 +373,7 @@ def test_parsing_nonstring_keys_and_fields(data_masker):
365373
3: {
366374
"1": {"None": "hello", "four": "world"},
367375
4: {"33": {"5": "goodbye", "e": "world"}},
368-
}
376+
},
369377
}
370378
masked = data_masker.mask(data, fields=[3.4])
371379

0 commit comments

Comments
 (0)