Skip to content

feat(parameters): Allow callable transforms #894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aws_lambda_powertools.utilities.parameters import AppConfigProvider, GetParameterError, TransformParameterError

from ... import Logger
from ...shared.types import AnyCallableT
from .base import StoreProvider
from .exceptions import ConfigurationStoreError, StoreClientError

Expand All @@ -22,6 +23,7 @@ def __init__(
name: str,
max_age: int = 5,
sdk_config: Optional[Config] = None,
transform: Optional[Union[AnyCallableT, str]] = TRANSFORM_TYPE,
envelope: Optional[str] = "",
jmespath_options: Optional[Dict] = None,
logger: Optional[Union[logging.Logger, Logger]] = None,
Expand Down Expand Up @@ -54,6 +56,7 @@ def __init__(
self.name = name
self.cache_seconds = max_age
self.config = sdk_config
self.transform = transform
self.envelope = envelope
self.jmespath_options = jmespath_options
self._conf_store = AppConfigProvider(environment=environment, application=application, config=sdk_config)
Expand All @@ -70,7 +73,7 @@ def get_raw_configuration(self) -> Dict[str, Any]:
dict,
self._conf_store.get(
name=self.name,
transform=TRANSFORM_TYPE,
transform=self.transform,
max_age=self.cache_seconds,
),
)
Expand Down
17 changes: 12 additions & 5 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple, Union

from ...shared.types import AnyCallableT
from .exceptions import GetParameterError, TransformParameterError

DEFAULT_MAX_AGE_SECS = 5
Expand All @@ -34,14 +35,14 @@ def __init__(self):

self.store = {}

def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
def _has_not_expired(self, key: Tuple[str, Optional[Union[AnyCallableT, str]]]) -> bool:
return key in self.store and self.store[key].ttl >= datetime.now()

def get(
self,
name: str,
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
transform: Optional[Union[AnyCallableT, str]] = None,
force_fetch: bool = False,
**sdk_options,
) -> Union[str, list, dict, bytes]:
Expand Down Expand Up @@ -112,7 +113,7 @@ def get_multiple(
self,
path: str,
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
transform: Optional[Union[AnyCallableT, str]] = None,
raise_on_transform_error: bool = False,
force_fetch: bool = False,
**sdk_options,
Expand Down Expand Up @@ -217,7 +218,11 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
return None


def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
def transform_value(
value: str,
transform: Union[AnyCallableT, str],
raise_on_transform_error: bool = True,
) -> Union[dict, bytes, None]:
"""
Apply a transform to a value

Expand All @@ -238,7 +243,9 @@ def transform_value(value: str, transform: str, raise_on_transform_error: bool =
"""

try:
if transform == TRANSFORM_METHOD_JSON:
if callable(transform):
return transform(value)
elif transform == TRANSFORM_METHOD_JSON:
return json.loads(value)
elif transform == TRANSFORM_METHOD_BINARY:
return base64.b64decode(value)
Expand Down
28 changes: 28 additions & 0 deletions tests/functional/feature_flags/test_feature_flags.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Dict, List, Optional

import pytest
Expand Down Expand Up @@ -1197,3 +1198,30 @@ def test_flags_greater_than_or_equal_match_2(mocker, config):
default=False,
)
assert toggle == expected_value


def test_get_feature_toggle_appconfig_store_callable_transform(mocker, config):

mock_schema = {
"ten_percent_off_campaign": {
"default": False,
},
}

mock_value = json.dumps(mock_schema)

mocked__get_conf = mocker.patch("aws_lambda_powertools.utilities.parameters.AppConfigProvider._get")
mocked__get_conf.return_value = mock_value

app_conf_fetcher = AppConfigStore(
environment="test_env",
application="test_app",
name="test_conf_name",
max_age=600,
sdk_config=config,
transform=json.loads,
)

feature_flags: FeatureFlags = FeatureFlags(store=app_conf_fetcher)

assert feature_flags.get_configuration() == mock_schema
121 changes: 121 additions & 0 deletions tests/functional/test_utilities_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,54 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert "Incorrect padding" in str(excinfo)


def test_base_provider_get_transform_callable(mock_name, mock_value):
"""
Test BaseProvider.get() with a callable transform
"""

mock_binary = mock_value.encode()
mock_data = base64.b16encode(mock_binary).decode()

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
assert name == mock_name
return mock_data

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
raise NotImplementedError()

provider = TestProvider()

value = provider.get(mock_name, transform=base64.b16decode)

assert isinstance(value, bytes)
assert value == mock_binary


def test_base_provider_get_transform_callable_exception(mock_name):
"""
Test BaseProvider.get() with a callable transform that raises an exception
"""

mock_data = "qw"
print(mock_data)

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
assert name == mock_name
return mock_data

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
raise NotImplementedError()

provider = TestProvider()

with pytest.raises(parameters.TransformParameterError) as excinfo:
provider.get(mock_name, transform=base64.b16decode)

assert "Non-base16 digit found" in str(excinfo)


def test_base_provider_get_multiple_transform_json(mock_name, mock_value):
"""
Test BaseProvider.get_multiple() with a json transform
Expand Down Expand Up @@ -1281,6 +1329,79 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert "Incorrect padding" in str(excinfo)


def test_base_provider_get_multiple_transform_callable(mock_name, mock_value):
"""
Test BaseProvider.get_multiple() with a callable transform
"""

mock_binary = mock_value.encode()
mock_data = base64.b16encode(mock_binary).decode()

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
raise NotImplementedError()

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert path == mock_name
return {"A": mock_data}

provider = TestProvider()

value = provider.get_multiple(mock_name, transform=base64.b16decode)

assert isinstance(value, dict)
assert value["A"] == mock_binary


def test_base_provider_get_multiple_transform_callable_partial_failure(mock_name, mock_value):
"""
Test BaseProvider.get_multiple() with a callable transform that contains a partial failure
"""

mock_binary = mock_value.encode()
mock_data_a = base64.b16encode(mock_binary).decode()
mock_data_b = "qw"

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
raise NotImplementedError()

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert path == mock_name
return {"A": mock_data_a, "B": mock_data_b}

provider = TestProvider()

value = provider.get_multiple(mock_name, transform=base64.b16decode)

assert isinstance(value, dict)
assert value["A"] == mock_binary
assert value["B"] is None


def test_base_provider_get_multiple_transform_callable_exception(mock_name):
"""
Test BaseProvider.get_multiple() with a callable transform that raises an exception
"""

mock_data = "qw"

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
raise NotImplementedError()

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert path == mock_name
return {"A": mock_data}

provider = TestProvider()

with pytest.raises(parameters.TransformParameterError) as excinfo:
provider.get_multiple(mock_name, transform=base64.b16decode, raise_on_transform_error=True)

assert "Non-base16 digit found" in str(excinfo)


def test_base_provider_get_multiple_cached(mock_name, mock_value):
"""
Test BaseProvider.get_multiple() with cached values
Expand Down