Skip to content

fix(idempotency): compute hashes once to guard against inadvertent mutat… #1094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions aws_lambda_powertools/utilities/idempotency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __init__(

persistence_store.configure(config, self.function.__name__)
self.persistence_store = persistence_store
self.hashed_idempotency_key = persistence_store.get_hashed_idempotency_key(self.data)
self.hashed_payload = persistence_store.get_hashed_payload(self.data)

def handle(self) -> Any:
"""
Expand All @@ -100,7 +102,7 @@ def _process_idempotency(self):
try:
# We call save_inprogress first as an optimization for the most common case where no idempotent record
# already exists. If it succeeds, there's no need to call get_record.
self.persistence_store.save_inprogress(data=self.data)
self.persistence_store.save_inprogress(hashed_key=self.hashed_idempotency_key, hashed_payload=self.hashed_payload)
except IdempotencyKeyError:
raise
except IdempotencyItemAlreadyExistsError:
Expand All @@ -122,7 +124,7 @@ def _get_idempotency_record(self) -> DataRecord:

"""
try:
data_record = self.persistence_store.get_record(data=self.data)
data_record = self.persistence_store.get_record(data=self.data, hashed_key=self.hashed_idempotency_key)
except IdempotencyItemNotFoundError:
# This code path will only be triggered if the record is removed between save_inprogress and get_record.
logger.debug(
Expand Down Expand Up @@ -180,7 +182,7 @@ def _get_function_response(self):
# We need these nested blocks to preserve function's exception in case the persistence store operation
# also raises an exception
try:
self.persistence_store.delete_record(data=self.data, exception=handler_exception)
self.persistence_store.delete_record(hashed_key=self.hashed_idempotency_key, exception=handler_exception)
except Exception as delete_exception:
raise IdempotencyPersistenceLayerError(
"Failed to delete record from idempotency store"
Expand All @@ -189,7 +191,7 @@ def _get_function_response(self):

else:
try:
self.persistence_store.save_success(data=self.data, result=response)
self.persistence_store.save_success(hashed_key=self.hashed_idempotency_key, hashed_payload=self.hashed_payload, result=response)
except Exception as save_exception:
raise IdempotencyPersistenceLayerError(
"Failed to update record state to success in idempotency store"
Expand Down
30 changes: 15 additions & 15 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def configure(self, config: IdempotencyConfig, function_name: Optional[str] = No
self._cache = LRUDict(max_items=config.local_cache_max_items)
self.hash_function = getattr(hashlib, config.hash_function)

def _get_hashed_idempotency_key(self, data: Dict[str, Any]) -> str:
def get_hashed_idempotency_key(self, data: Dict[str, Any]) -> str:
"""
Extract idempotency key and return a hashed representation

Expand Down Expand Up @@ -189,7 +189,7 @@ def is_missing_idempotency_key(data) -> bool:
return all(x is None for x in data)
return not data

def _get_hashed_payload(self, data: Dict[str, Any]) -> str:
def get_hashed_payload(self, data: Dict[str, Any]) -> str:
"""
Extract payload using validation key jmespath and return a hashed representation

Expand Down Expand Up @@ -245,7 +245,7 @@ def _validate_payload(self, data: Dict[str, Any], data_record: DataRecord) -> No

"""
if self.payload_validation_enabled:
data_hash = self._get_hashed_payload(data=data)
data_hash = self.get_hashed_payload(data=data)
if data_record.payload_hash != data_hash:
raise IdempotencyValidationError("Payload does not match stored record for this event key")

Expand Down Expand Up @@ -300,7 +300,7 @@ def _delete_from_cache(self, idempotency_key: str):
if idempotency_key in self._cache:
del self._cache[idempotency_key]

def save_success(self, data: Dict[str, Any], result: dict) -> None:
def save_success(self, hashed_key: str, hashed_payload: str, result: dict) -> None:
"""
Save record of function's execution completing successfully

Expand All @@ -314,11 +314,11 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None:
response_data = json.dumps(result, cls=Encoder, sort_keys=True)

data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(data=data),
idempotency_key=hashed_key,
status=STATUS_CONSTANTS["COMPLETED"],
expiry_timestamp=self._get_expiry_timestamp(),
response_data=response_data,
payload_hash=self._get_hashed_payload(data=data),
payload_hash=hashed_payload,
)
logger.debug(
f"Function successfully executed. Saving record to persistence store with "
Expand All @@ -328,20 +328,19 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None:

self._save_to_cache(data_record=data_record)

def save_inprogress(self, data: Dict[str, Any]) -> None:
def save_inprogress(self, hashed_key=None, hashed_payload=None) -> None:
"""
Save record of function's execution being in progress

Parameters
----------
data: Dict[str, Any]
Payload
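hashed_key: str, optional
Pre-computed hashed idempotency key, stored as the record's idempotency_key
hashed_payload: str, optional
Pre-computed payload hash, stored as the record's payload_hash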
"""
data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(data=data),
idempotency_key=hashed_key,
status=STATUS_CONSTANTS["INPROGRESS"],
expiry_timestamp=self._get_expiry_timestamp(),
payload_hash=self._get_hashed_payload(data=data),
payload_hash=hashed_payload,
)

logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}")
Expand All @@ -351,7 +350,7 @@ def save_inprogress(self, data: Dict[str, Any]) -> None:

self._put_record(data_record=data_record)

def delete_record(self, data: Dict[str, Any], exception: Exception):
def delete_record(self, hashed_key: str, exception: Exception):
"""
Delete record from the persistence store

Expand All @@ -362,7 +361,7 @@ def delete_record(self, data: Dict[str, Any], exception: Exception):
exception
The exception raised by the function
"""
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(data=data))
data_record = DataRecord(idempotency_key=hashed_key)

logger.debug(
f"Function raised an exception ({type(exception).__name__}). Clearing in progress record in persistence "
Expand All @@ -372,7 +371,7 @@ def delete_record(self, data: Dict[str, Any], exception: Exception):

self._delete_from_cache(idempotency_key=data_record.idempotency_key)

def get_record(self, data: Dict[str, Any]) -> DataRecord:
def get_record(self, data: Dict[str, Any], hashed_key: str) -> DataRecord:
"""
Retrieve idempotency key for data provided, fetch from persistence store, and convert to DataRecord.

Expand All @@ -394,7 +393,7 @@ def get_record(self, data: Dict[str, Any]) -> DataRecord:
Payload doesn't match the stored record for the given idempotency key
"""

idempotency_key = self._get_hashed_idempotency_key(data=data)
idempotency_key = hashed_key

cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
Expand All @@ -409,6 +408,7 @@ def get_record(self, data: Dict[str, Any]) -> DataRecord:
self._validate_payload(data=data, data_record=record)
return record

# TODO these would all need to be updated too...
@abstractmethod
def _get_record(self, idempotency_key) -> DataRecord:
"""
Expand Down
42 changes: 38 additions & 4 deletions tests/functional/idempotency/test_idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,40 @@ def lambda_handler(event, context):
stubber.assert_no_pending_responses()
stubber.deactivate()

@pytest.mark.parametrize("idempotency_config", [{"use_local_cache": False}, {"use_local_cache": True}], indirect=True)
def test_idempotent_lambda_first_execution_event_mutation(
    idempotency_config: IdempotencyConfig,
    persistence_store: DynamoDBPersistenceLayer,
    lambda_apigw_event,
    expected_params_update_item,
    expected_params_put_item,
    lambda_response,
    serialized_lambda_response,
    deserialized_lambda_response,
    hashed_idempotency_key,
    lambda_context,
):
    """
    First execution of an idempotent handler that mutates its own event.

    The stubbed DynamoDB client is primed with exactly one ``put_item`` and one
    ``update_item`` response, each tied to the fixture-provided expected
    parameters. The decorated handler removes an entry from the event before
    returning, and the test then checks that both stubbed calls were consumed.
    """
    ddb_stubber = stub.Stubber(persistence_store.table.meta.client)
    empty_ddb_response = {}

    # Queue the two persistence calls in the order they are expected to happen.
    expected_calls = (
        ("put_item", expected_params_put_item),
        ("update_item", expected_params_update_item),
    )
    for operation, expected_params in expected_calls:
        ddb_stubber.add_response(operation, empty_ddb_response, expected_params)
    ddb_stubber.activate()

    # NOTE(review): the handler's __name__ is handed to persistence_store.configure,
    # so renaming it may change the expected idempotency key - confirm against fixtures.
    @idempotent(config=idempotency_config, persistence_store=persistence_store)
    def lambda_handler(event, context):
        # Mutate the incoming event in place before producing the response.
        event.popitem()
        return lambda_response

    lambda_handler(lambda_apigw_event, lambda_context)

    # Both queued responses must have been used by the time the handler returns.
    ddb_stubber.assert_no_pending_responses()
    ddb_stubber.deactivate()


@pytest.mark.parametrize("idempotency_config", [{"use_local_cache": False}, {"use_local_cache": True}], indirect=True)
def test_idempotent_lambda_expired(
Expand Down Expand Up @@ -770,7 +804,7 @@ def test_default_no_raise_on_missing_idempotency_key(
assert "body" in persistence_store.event_key_jmespath

# WHEN getting the hashed idempotency key for an event with no `body` key
hashed_key = persistence_store._get_hashed_idempotency_key({})
hashed_key = persistence_store.get_hashed_idempotency_key({})

# THEN return the hash of None
expected_value = f"test-func.{function_name}#" + md5(json_serialize(None).encode()).hexdigest()
Expand All @@ -791,7 +825,7 @@ def test_raise_on_no_idempotency_key(

# WHEN getting the hashed idempotency key for an event with no `body` key
with pytest.raises(IdempotencyKeyError) as excinfo:
persistence_store._get_hashed_idempotency_key({})
persistence_store.get_hashed_idempotency_key({})

# THEN raise IdempotencyKeyError error
assert "No data found to create a hashed idempotency_key" in str(excinfo.value)
Expand Down Expand Up @@ -821,7 +855,7 @@ def test_jmespath_with_powertools_json(
}

# WHEN calling get_hashed_idempotency_key
result = persistence_store._get_hashed_idempotency_key(api_gateway_proxy_event)
result = persistence_store.get_hashed_idempotency_key(api_gateway_proxy_event)

# THEN the hashed idempotency key should match the extracted values generated hash
assert result == "test-func.handler#" + persistence_store._generate_hash(expected_value)
Expand All @@ -838,7 +872,7 @@ def test_custom_jmespath_function_overrides_builtin_functions(
with pytest.raises(jmespath.exceptions.UnknownFunctionError, match="Unknown function: powertools_json()"):
# WHEN calling get_hashed_idempotency_key
# THEN raise unknown function
persistence_store._get_hashed_idempotency_key({})
persistence_store.get_hashed_idempotency_key({})


def test_idempotent_lambda_save_inprogress_error(persistence_store: DynamoDBPersistenceLayer, lambda_context):
Expand Down