diff --git a/aws_lambda_powertools/utilities/data_masking/base.py b/aws_lambda_powertools/utilities/data_masking/base.py index 1541e6f761b..0c58ee2a861 100644 --- a/aws_lambda_powertools/utilities/data_masking/base.py +++ b/aws_lambda_powertools/utilities/data_masking/base.py @@ -11,7 +11,7 @@ import logging import warnings from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any from jsonpath_ng.ext import parse @@ -23,6 +23,7 @@ from aws_lambda_powertools.warnings import PowertoolsUserWarning if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence from numbers import Number logger = logging.getLogger(__name__) diff --git a/aws_lambda_powertools/utilities/data_masking/provider/base.py b/aws_lambda_powertools/utilities/data_masking/provider/base.py index 16fa22d16b8..e8724c5a4de 100644 --- a/aws_lambda_powertools/utilities/data_masking/provider/base.py +++ b/aws_lambda_powertools/utilities/data_masking/provider/base.py @@ -3,13 +3,18 @@ import functools import json import re -from typing import Any, Callable +from typing import TYPE_CHECKING, Any from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING +if TYPE_CHECKING: + from collections.abc import Callable + PRESERVE_CHARS = set("-_. ") _regex_cache = {} +JSON_DUMPS_CALL = functools.partial(json.dumps, ensure_ascii=False) + class BaseProvider: """ @@ -49,7 +54,7 @@ def lambda_handler(event, context): def __init__( self, - json_serializer: Callable[..., str] = functools.partial(json.dumps, ensure_ascii=False), + json_serializer: Callable[..., str] = JSON_DUMPS_CALL, json_deserializer: Callable[[str], Any] = json.loads, ) -> None: self.json_serializer = json_serializer diff --git a/aws_lambda_powertools/utilities/data_masking/provider/kms/aws_encryption_sdk.py b/aws_lambda_powertools/utilities/data_masking/provider/kms/aws_encryption_sdk.py index 07d48efe569..c9c902d51cc 100644 --- a/aws_lambda_powertools/utilities/data_masking/provider/kms/aws_encryption_sdk.py +++ b/aws_lambda_powertools/utilities/data_masking/provider/kms/aws_encryption_sdk.py @@ -4,7 +4,7 @@ import json import logging from binascii import Error -from typing import Any, Callable +from typing import TYPE_CHECKING, Any import botocore from aws_encryption_sdk import ( @@ -41,8 +41,13 @@ ) from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider +if TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) +JSON_DUMPS_CALL = functools.partial(json.dumps, ensure_ascii=False) + class AWSEncryptionSDKProvider(BaseProvider): """ @@ -81,7 +86,7 @@ def __init__( max_cache_age_seconds: float = MAX_CACHE_AGE_SECONDS, max_messages_encrypted: int = MAX_MESSAGES_ENCRYPTED, max_bytes_encrypted: int = MAX_BYTES_ENCRYPTED, - json_serializer: Callable[..., str] = functools.partial(json.dumps, ensure_ascii=False), + json_serializer: Callable[..., str] = JSON_DUMPS_CALL, json_deserializer: Callable[[str], Any] = json.loads, ): super().__init__(json_serializer=json_serializer, json_deserializer=json_deserializer) diff --git a/tests/functional/data_masking/_aws_encryption_sdk/test_aws_encryption_sdk.py b/tests/functional/data_masking/_aws_encryption_sdk/test_aws_encryption_sdk.py index 63aca871e44..039e302bf93 100644 --- a/tests/functional/data_masking/_aws_encryption_sdk/test_aws_encryption_sdk.py +++ b/tests/functional/data_masking/_aws_encryption_sdk/test_aws_encryption_sdk.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import base64 import functools import json -from typing import Any, Callable, Union +from typing import TYPE_CHECKING, Any import pytest from aws_encryption_sdk.identifiers import Algorithm @@ -13,16 +15,21 @@ AWSEncryptionSDKProvider, ) +if TYPE_CHECKING: + from collections.abc import Callable + +JSON_DUMPS_CALL = functools.partial(json.dumps, ensure_ascii=False) + class FakeEncryptionKeyProvider(BaseProvider): def __init__( self, - json_serializer: Callable = functools.partial(json.dumps, ensure_ascii=False), + json_serializer: Callable = JSON_DUMPS_CALL, json_deserializer: Callable = json.loads, ) -> None: super().__init__(json_serializer, json_deserializer) - def encrypt(self, data: Union[bytes, str], **kwargs) -> str: + def encrypt(self, data: bytes | str, **kwargs) -> str: encoded_data: str = self.json_serializer(data) ciphertext = base64.b64encode(encoded_data.encode("utf-8")).decode() return ciphertext diff --git a/tests/functional/data_masking/_pydantic/test_data_masking_with_pydantic.py b/tests/functional/data_masking/_pydantic/test_data_masking_with_pydantic.py index b2bc94ed2ef..40ccf5280f8 100644 --- a/tests/functional/data_masking/_pydantic/test_data_masking_with_pydantic.py +++ b/tests/functional/data_masking/_pydantic/test_data_masking_with_pydantic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import pytest diff --git a/tests/functional/data_masking/conftest.py b/tests/functional/data_masking/conftest.py index f73ccca4113..15ce865abfa 100644 --- a/tests/functional/data_masking/conftest.py +++ b/tests/functional/data_masking/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pytest_socket import disable_socket diff --git a/tests/functional/data_masking/required_dependencies/test_erase_data_masking.py b/tests/functional/data_masking/required_dependencies/test_erase_data_masking.py index 12ffd054376..6aac48927da 100644 --- a/tests/functional/data_masking/required_dependencies/test_erase_data_masking.py +++ b/tests/functional/data_masking/required_dependencies/test_erase_data_masking.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import pytest diff --git a/tests/unit/data_masking/_aws_encryption_sdk/test_kms_provider.py b/tests/unit/data_masking/_aws_encryption_sdk/test_kms_provider.py index 5fe9b2e53ed..d736fafe6b6 100644 --- a/tests/unit/data_masking/_aws_encryption_sdk/test_kms_provider.py +++ b/tests/unit/data_masking/_aws_encryption_sdk/test_kms_provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from aws_lambda_powertools.utilities.data_masking.exceptions import ( diff --git a/tests/unit/data_masking/required_dependencies/test_base_functions.py b/tests/unit/data_masking/required_dependencies/test_base_functions.py index 1af532967c7..9ce3de65cfb 100644 --- a/tests/unit/data_masking/required_dependencies/test_base_functions.py +++ b/tests/unit/data_masking/required_dependencies/test_base_functions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from aws_lambda_powertools.utilities.data_masking.base import DataMasking