From 4c9ef74a163d686133b1e43d927d1cf3b89b2f5b Mon Sep 17 00:00:00 2001 From: Otto Jongerius Date: Tue, 5 Apr 2022 13:26:21 +1200 Subject: [PATCH] fix(idempotency): compute hashes once to guard against inadvertent mutations This protects against functions that inadvertently change the incoming event data --- .../utilities/idempotency/base.py | 10 +++-- .../utilities/idempotency/persistence/base.py | 30 ++++++------- .../idempotency/test_idempotency.py | 42 +++++++++++++++++-- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/aws_lambda_powertools/utilities/idempotency/base.py b/aws_lambda_powertools/utilities/idempotency/base.py index dddc36b426d..2d63a7e7db6 100644 --- a/aws_lambda_powertools/utilities/idempotency/base.py +++ b/aws_lambda_powertools/utilities/idempotency/base.py @@ -75,6 +75,8 @@ def __init__( persistence_store.configure(config, self.function.__name__) self.persistence_store = persistence_store + self.hashed_idempotency_key = persistence_store.get_hashed_idempotency_key(self.data) + self.hashed_payload = persistence_store.get_hashed_payload(self.data) def handle(self) -> Any: """ @@ -100,7 +102,7 @@ def _process_idempotency(self): try: # We call save_inprogress first as an optimization for the most common case where no idempotent record # already exists. If it succeeds, there's no need to call get_record. - self.persistence_store.save_inprogress(data=self.data) + self.persistence_store.save_inprogress(hashed_key=self.hashed_idempotency_key, hashed_payload=self.hashed_payload) except IdempotencyKeyError: raise except IdempotencyItemAlreadyExistsError: @@ -122,7 +124,7 @@ def _get_idempotency_record(self) -> DataRecord: """ try: - data_record = self.persistence_store.get_record(data=self.data) + data_record = self.persistence_store.get_record(data=self.data, hashed_key=self.hashed_idempotency_key) except IdempotencyItemNotFoundError: # This code path will only be triggered if the record is removed between save_inprogress and get_record. 
logger.debug( @@ -180,7 +182,7 @@ def _get_function_response(self): # We need these nested blocks to preserve function's exception in case the persistence store operation # also raises an exception try: - self.persistence_store.delete_record(data=self.data, exception=handler_exception) + self.persistence_store.delete_record(hashed_key=self.hashed_idempotency_key, exception=handler_exception) except Exception as delete_exception: raise IdempotencyPersistenceLayerError( "Failed to delete record from idempotency store" @@ -189,7 +191,7 @@ def _get_function_response(self): else: try: - self.persistence_store.save_success(data=self.data, result=response) + self.persistence_store.save_success(hashed_key=self.hashed_idempotency_key, hashed_payload=self.hashed_payload, result=response) except Exception as save_exception: raise IdempotencyPersistenceLayerError( "Failed to update record state to success in idempotency store" diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index e6ffea10de8..4d7cf40c859 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -157,7 +157,7 @@ def configure(self, config: IdempotencyConfig, function_name: Optional[str] = No self._cache = LRUDict(max_items=config.local_cache_max_items) self.hash_function = getattr(hashlib, config.hash_function) - def _get_hashed_idempotency_key(self, data: Dict[str, Any]) -> str: + def get_hashed_idempotency_key(self, data: Dict[str, Any]) -> str: """ Extract idempotency key and return a hashed representation @@ -189,7 +189,7 @@ def is_missing_idempotency_key(data) -> bool: return all(x is None for x in data) return not data - def _get_hashed_payload(self, data: Dict[str, Any]) -> str: + def get_hashed_payload(self, data: Dict[str, Any]) -> str: """ Extract payload using validation key jmespath and return a hashed 
representation @@ -245,7 +245,7 @@ def _validate_payload(self, data: Dict[str, Any], data_record: DataRecord) -> No """ if self.payload_validation_enabled: - data_hash = self._get_hashed_payload(data=data) + data_hash = self.get_hashed_payload(data=data) if data_record.payload_hash != data_hash: raise IdempotencyValidationError("Payload does not match stored record for this event key") @@ -300,7 +300,7 @@ def _delete_from_cache(self, idempotency_key: str): if idempotency_key in self._cache: del self._cache[idempotency_key] - def save_success(self, data: Dict[str, Any], result: dict) -> None: + def save_success(self, hashed_key: str, hashed_payload: str, result: dict) -> None: """ Save record of function's execution completing successfully @@ -314,11 +314,11 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None: response_data = json.dumps(result, cls=Encoder, sort_keys=True) data_record = DataRecord( - idempotency_key=self._get_hashed_idempotency_key(data=data), + idempotency_key=hashed_key, status=STATUS_CONSTANTS["COMPLETED"], expiry_timestamp=self._get_expiry_timestamp(), response_data=response_data, - payload_hash=self._get_hashed_payload(data=data), + payload_hash=hashed_payload, ) logger.debug( f"Function successfully executed. 
Saving record to persistence store with " @@ -328,20 +328,19 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None: self._save_to_cache(data_record=data_record) - def save_inprogress(self, data: Dict[str, Any]) -> None: + def save_inprogress(self, hashed_key=None, hashed_payload=None) -> None: """ Save record of function's execution being in progress Parameters ---------- - data: Dict[str, Any] - Payload + TODO: update """ data_record = DataRecord( - idempotency_key=self._get_hashed_idempotency_key(data=data), + idempotency_key=hashed_key, status=STATUS_CONSTANTS["INPROGRESS"], expiry_timestamp=self._get_expiry_timestamp(), - payload_hash=self._get_hashed_payload(data=data), + payload_hash=hashed_payload, ) logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}") @@ -351,7 +350,7 @@ def save_inprogress(self, data: Dict[str, Any]) -> None: self._put_record(data_record=data_record) - def delete_record(self, data: Dict[str, Any], exception: Exception): + def delete_record(self, hashed_key: str, exception: Exception): """ Delete record from the persistence store @@ -362,7 +361,7 @@ def delete_record(self, data: Dict[str, Any], exception: Exception): exception The exception raised by the function """ - data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(data=data)) + data_record = DataRecord(idempotency_key=hashed_key) logger.debug( f"Function raised an exception ({type(exception).__name__}). Clearing in progress record in persistence " @@ -372,7 +371,7 @@ def delete_record(self, data: Dict[str, Any], exception: Exception): self._delete_from_cache(idempotency_key=data_record.idempotency_key) - def get_record(self, data: Dict[str, Any]) -> DataRecord: + def get_record(self, data: Dict[str, Any], hashed_key: str) -> DataRecord: """ Retrieve idempotency key for data provided, fetch from persistence store, and convert to DataRecord. 
@@ -394,7 +393,7 @@ def get_record(self, data: Dict[str, Any]) -> DataRecord: Payload doesn't match the stored record for the given idempotency key """ - idempotency_key = self._get_hashed_idempotency_key(data=data) + idempotency_key = hashed_key cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key) if cached_record: @@ -409,6 +408,7 @@ def get_record(self, data: Dict[str, Any]) -> DataRecord: self._validate_payload(data=data, data_record=record) return record + # TODO these would all need to be updated too... @abstractmethod def _get_record(self, idempotency_key) -> DataRecord: """ diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index 5b76cda0475..45961b10987 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -274,6 +274,40 @@ def lambda_handler(event, context): stubber.assert_no_pending_responses() stubber.deactivate() +@pytest.mark.parametrize("idempotency_config", [{"use_local_cache": False}, {"use_local_cache": True}], indirect=True) +def test_idempotent_lambda_first_execution_event_mutation( + idempotency_config: IdempotencyConfig, + persistence_store: DynamoDBPersistenceLayer, + lambda_apigw_event, + expected_params_update_item, + expected_params_put_item, + lambda_response, + serialized_lambda_response, + deserialized_lambda_response, + hashed_idempotency_key, + lambda_context, +): + """ + Test idempotent decorator where lambda_handler mutates the event + """ + + stubber = stub.Stubber(persistence_store.table.meta.client) + ddb_response = {} + + stubber.add_response("put_item", ddb_response, expected_params_put_item) + stubber.add_response("update_item", ddb_response, expected_params_update_item) + stubber.activate() + + @idempotent(config=idempotency_config, persistence_store=persistence_store) + def lambda_handler(event, context): + event.popitem() + return lambda_response + + 
lambda_handler(lambda_apigw_event, lambda_context) + + stubber.assert_no_pending_responses() + stubber.deactivate() + @pytest.mark.parametrize("idempotency_config", [{"use_local_cache": False}, {"use_local_cache": True}], indirect=True) def test_idempotent_lambda_expired( @@ -770,7 +804,7 @@ def test_default_no_raise_on_missing_idempotency_key( assert "body" in persistence_store.event_key_jmespath # WHEN getting the hashed idempotency key for an event with no `body` key - hashed_key = persistence_store._get_hashed_idempotency_key({}) + hashed_key = persistence_store.get_hashed_idempotency_key({}) # THEN return the hash of None expected_value = f"test-func.{function_name}#" + md5(json_serialize(None).encode()).hexdigest() @@ -791,7 +825,7 @@ def test_raise_on_no_idempotency_key( # WHEN getting the hashed idempotency key for an event with no `body` key with pytest.raises(IdempotencyKeyError) as excinfo: - persistence_store._get_hashed_idempotency_key({}) + persistence_store.get_hashed_idempotency_key({}) # THEN raise IdempotencyKeyError error assert "No data found to create a hashed idempotency_key" in str(excinfo.value) @@ -821,7 +855,7 @@ def test_jmespath_with_powertools_json( } # WHEN calling _get_hashed_idempotency_key - result = persistence_store._get_hashed_idempotency_key(api_gateway_proxy_event) + result = persistence_store.get_hashed_idempotency_key(api_gateway_proxy_event) # THEN the hashed idempotency key should match the extracted values generated hash assert result == "test-func.handler#" + persistence_store._generate_hash(expected_value) @@ -838,7 +872,7 @@ def test_custom_jmespath_function_overrides_builtin_functions( with pytest.raises(jmespath.exceptions.UnknownFunctionError, match="Unknown function: powertools_json()"): # WHEN calling _get_hashed_idempotency_key # THEN raise unknown function - persistence_store._get_hashed_idempotency_key({}) + persistence_store.get_hashed_idempotency_key({}) def 
test_idempotent_lambda_save_inprogress_error(persistence_store: DynamoDBPersistenceLayer, lambda_context):