diff --git a/aws_lambda_powertools/utilities/feature_flags/appconfig.py b/aws_lambda_powertools/utilities/feature_flags/appconfig.py index dd581df9e22..9871df7314f 100644 --- a/aws_lambda_powertools/utilities/feature_flags/appconfig.py +++ b/aws_lambda_powertools/utilities/feature_flags/appconfig.py @@ -8,6 +8,7 @@ from aws_lambda_powertools.utilities.parameters import AppConfigProvider, GetParameterError, TransformParameterError from ... import Logger +from ...shared.types import AnyCallableT from .base import StoreProvider from .exceptions import ConfigurationStoreError, StoreClientError @@ -22,6 +23,7 @@ def __init__( name: str, max_age: int = 5, sdk_config: Optional[Config] = None, + transform: Optional[Union[AnyCallableT, str]] = TRANSFORM_TYPE, envelope: Optional[str] = "", jmespath_options: Optional[Dict] = None, logger: Optional[Union[logging.Logger, Logger]] = None, @@ -54,6 +56,7 @@ def __init__( self.name = name self.cache_seconds = max_age self.config = sdk_config + self.transform = transform self.envelope = envelope self.jmespath_options = jmespath_options self._conf_store = AppConfigProvider(environment=environment, application=application, config=sdk_config) @@ -70,7 +73,7 @@ def get_raw_configuration(self) -> Dict[str, Any]: dict, self._conf_store.get( name=self.name, - transform=TRANSFORM_TYPE, + transform=self.transform, max_age=self.cache_seconds, ), ) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index b059a3b2483..6d5876ea94d 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import Any, Dict, Optional, Tuple, Union +from ...shared.types import AnyCallableT from .exceptions import GetParameterError, TransformParameterError DEFAULT_MAX_AGE_SECS = 5 @@ -34,14 +35,14 @@ def __init__(self): self.store = {} - def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: + def _has_not_expired(self, key: Tuple[str, Optional[Union[AnyCallableT, str]]]) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: Optional[Union[AnyCallableT, str]] = None, force_fetch: bool = False, **sdk_options, ) -> Union[str, list, dict, bytes]: @@ -112,7 +113,7 @@ def get_multiple( self, path: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: Optional[Union[AnyCallableT, str]] = None, raise_on_transform_error: bool = False, force_fetch: bool = False, **sdk_options, @@ -217,7 +218,11 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[ return None -def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]: +def transform_value( + value: str, + transform: Union[AnyCallableT, str], + raise_on_transform_error: bool = True, +) -> Union[dict, bytes, None]: """ Apply a transform to a value @@ -238,7 +243,9 @@ def transform_value(value: str, transform: str, raise_on_transform_error: bool = """ try: - if transform == TRANSFORM_METHOD_JSON: + if callable(transform): + return transform(value) + elif transform == TRANSFORM_METHOD_JSON: return json.loads(value) elif transform == TRANSFORM_METHOD_BINARY: return base64.b64decode(value) diff --git a/tests/functional/feature_flags/test_feature_flags.py b/tests/functional/feature_flags/test_feature_flags.py index 8381dc6bf1d..82aceb235a9 100644 --- a/tests/functional/feature_flags/test_feature_flags.py +++ b/tests/functional/feature_flags/test_feature_flags.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Optional import pytest @@ -1197,3 +1198,30 @@ def test_flags_greater_than_or_equal_match_2(mocker, config): default=False, ) assert toggle == expected_value + + +def test_get_feature_toggle_appconfig_store_callable_transform(mocker, config): + + mock_schema = { + "ten_percent_off_campaign": { + "default": False, + }, + } + + mock_value = json.dumps(mock_schema) + + mocked__get_conf = mocker.patch("aws_lambda_powertools.utilities.parameters.AppConfigProvider._get") + mocked__get_conf.return_value = mock_value + + app_conf_fetcher = AppConfigStore( + environment="test_env", + application="test_app", + name="test_conf_name", + max_age=600, + sdk_config=config, + transform=json.loads, + ) + + feature_flags: FeatureFlags = FeatureFlags(store=app_conf_fetcher) + + assert feature_flags.get_configuration() == mock_schema diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 47fc5a0e982..3cf59ea15ef 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1138,6 +1138,54 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert "Incorrect padding" in str(excinfo) +def test_base_provider_get_transform_callable(mock_name, mock_value): + """ + Test BaseProvider.get() with a callable transform + """ + + mock_binary = mock_value.encode() + mock_data = base64.b16encode(mock_binary).decode() + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + value = provider.get(mock_name, transform=base64.b16decode) + + assert isinstance(value, bytes) + assert value == mock_binary + + +def test_base_provider_get_transform_callable_exception(mock_name): + """ + Test BaseProvider.get() with a callable transform that raises an exception + """ + + mock_data = "qw" + print(mock_data) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get(mock_name, transform=base64.b16decode) + + assert "Non-base16 digit found" in str(excinfo) + + def test_base_provider_get_multiple_transform_json(mock_name, mock_value): """ Test BaseProvider.get_multiple() with a json transform @@ -1281,6 +1329,79 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert "Incorrect padding" in str(excinfo) +def test_base_provider_get_multiple_transform_callable(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a callable transform + """ + + mock_binary = mock_value.encode() + mock_data = base64.b16encode(mock_binary).decode() + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform=base64.b16decode) + + assert isinstance(value, dict) + assert value["A"] == mock_binary + + +def test_base_provider_get_multiple_transform_callable_partial_failure(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a callable transform that contains a partial failure + """ + + mock_binary = mock_value.encode() + mock_data_a = base64.b16encode(mock_binary).decode() + mock_data_b = "qw" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data_a, "B": mock_data_b} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform=base64.b16decode) + + assert isinstance(value, dict) + assert value["A"] == mock_binary + assert value["B"] is None + + +def test_base_provider_get_multiple_transform_callable_exception(mock_name): + """ + Test BaseProvider.get_multiple() with a callable transform that raises an exception + """ + + mock_data = "qw" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get_multiple(mock_name, transform=base64.b16decode, raise_on_transform_error=True) + + assert "Non-base16 digit found" in str(excinfo) + + def test_base_provider_get_multiple_cached(mock_name, mock_value): """ Test BaseProvider.get_multiple() with cached values